polish code with new pre-commit rule (#2923)

Mirror of https://github.com/PaddlePaddle/FastDeploy.git
.flake8 (new file, 7 lines)
@@ -0,0 +1,7 @@
+[flake8]
+ignore = E203, E402, E501, E731, E741, W503, W605, E722
+max-line-length = 119
+
+# E402: module level import not at top of file
+per-file-ignores =
+    __init__.py:F401,F403,E402
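With this file at the repository root, flake8 picks the settings up automatically; a minimal local check (assuming flake8 is installed, e.g. the 4.0.1 pin used by the pre-commit hook added below) looks like:

    pip install flake8==4.0.1
    flake8 .   # reads ignore, max-line-length and per-file-ignores from .flake8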
.github/workflows/ci.yml (vendored; 4 lines changed)
@@ -2,7 +2,7 @@ name: CI
 
 on:
   pull_request:
     branches:
       - develop
       - 'release/*'
   workflow_dispatch:
@@ -86,4 +86,4 @@ jobs:
           git config --global --add safe.directory /workspace/FastDeploy
           cd FastDeploy
           bash scripts/run_ci.sh
           "
.github/workflows/ci_xpu.yml (vendored; 6 lines changed)
@@ -2,7 +2,7 @@ name: CI_XPU
 
 on:
   pull_request:
     branches:
       - develop
       - 'release/*'
   workflow_dispatch:
@@ -63,7 +63,7 @@ jobs:
         if [[ "$last_char" =~ [0-3] ]]; then
           gpu_id="$last_char"
         else
           gpu_id="0"
         fi
         FD_API_PORT=$((9180 + gpu_id * 100))
         FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
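The arithmetic above derives per-GPU ports from the device index, spacing the API and engine-queue ports 100 apart per GPU; a quick sanity check of the formula in bash (illustrative value):

    gpu_id=3
    echo "FD_API_PORT=$((9180 + gpu_id * 100))"            # 9480
    echo "FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))"   # 9450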
@@ -84,4 +84,4 @@ jobs:
           git config --global --add safe.directory /workspace/FastDeploy
           cd FastDeploy
           bash scripts/run_ci_xpu.sh
           "
.pre-commit-config.yaml
@@ -5,12 +5,27 @@ default_stages:
   - pre-commit # Run locally
   # - manual # Run in CI
 repos:
+  - repo: https://github.com/psf/black.git
+    rev: 22.8.0
+    hooks:
+      - id: black
+        files: \.(py|pyi)$
+        additional_dependencies: [toml]
+  # auto-sort imports
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.11.5
+    hooks:
+      - id: isort
+  - repo: https://github.com/PyCQA/flake8
+    rev: 4.0.1
+    hooks:
+      - id: flake8
   # code checks
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.11.7
     hooks:
       - id: ruff
-        args: [--output-format, github, --fix, --line-length=120]
+        args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
   # # spell check
   # - repo: https://github.com/codespell-project/codespell
   #   rev: v2.4.1
@@ -18,17 +33,13 @@ repos:
   #     - id: codespell
   #       additional_dependencies: ['tomli']
   #       args: ['--toml', 'pyproject.toml']
-  # auto-sort imports
-  - repo: https://github.com/PyCQA/isort
-    rev: 6.0.1
-    hooks:
-      - id: isort
   # markdown
   - repo: https://github.com/jackdewinter/pymarkdown
     rev: v0.9.29
     hooks:
       - id: pymarkdown
-        args: [fix]
+        args: ["-d", "MD029,MD031", fix]
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
     hooks:
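To exercise the updated hook set locally, the standard pre-commit workflow applies (stock pre-commit commands, nothing FastDeploy-specific):

    pip install pre-commit
    pre-commit install           # run the hooks on every git commit
    pre-commit run --all-files   # apply black, isort, flake8, ruff and pymarkdown to the whole tree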
README.md
@@ -8,7 +8,7 @@
     <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
     <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
     <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
 
 </p>
 
 <p align="center">
@@ -17,8 +17,8 @@
 
     <a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
 
     <a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
 
 </p>
 
 --------------------------------------------------------------------------------
@@ -131,4 +131,4 @@ python benchmarks/benchmark_mtp.py \
 --s_itl-base-model: decode latency of the base model; obtained with the benchmarking tool above, one value per batch size
 --dataset-name: selects the dataset class; set to "EBChat" to read datasets dumped in FD format
 --dataset-path: path to the test dataset
 ```
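A sketch of how these flags combine in one invocation (host, port, prompt count and dataset path are placeholder values; only the three flags documented above come from this README, the rest follow the script's argparse names and are assumed):

    python benchmarks/benchmark_mtp.py \
        --host 127.0.0.1 --port 8180 \
        --num-prompts 128 \
        --max-concurrency 16 32 \
        --s_itl-base-model 22.0 16.0 \
        --dataset-name EBChat \
        --dataset-path ./eb_chat_dataset.json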
benchmarks/backend_request_func.py
@@ -29,13 +29,13 @@ from typing import Optional
 import aiohttp
 from tqdm.asyncio import tqdm
 
 
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 
 
 @dataclass
 class RequestFuncInput:
     """Input for requesting LLMs via API"""
 
     no: int
     prompt: str
     history_QA: Optional[dict]
@@ -55,6 +55,7 @@ class RequestFuncInput:
 @dataclass
 class RequestFuncOutput:
     """Output for requesting LLMs via API"""
+
     no: int = 0
     generated_text: str = ""
     reasoning_content: str = ""
@@ -66,7 +67,7 @@ class RequestFuncOutput:
     itl: list = field(default_factory=list)  # list of inter-token latencies
     tpot: float = 0.0  # avg next-token latencies
     prompt_len: int = 0
     prompt_tokens: int = 0  # number of input tokens reported by the inference side
     error: str = ""
 
 
@@ -76,12 +77,9 @@ async def async_request_eb_openai_chat_completions(
 ) -> RequestFuncOutput:
     """Request an LLM using EB OpenAI"""
     api_url = request_func_input.api_url
-    assert api_url.endswith(
-        ("completions", "profile")
-    ), "OpenAI Chat Completions API URL must end with 'completions'."
+    assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         content = [{"type": "text", "text": request_func_input.prompt}]
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
@@ -91,7 +89,7 @@ async def async_request_eb_openai_chat_completions(
             "stream": True,
             "stream_options": {
                 "include_usage": True,
-                "continuous_usage_stats": True
+                "continuous_usage_stats": True,
             },
         }
         # hyperparameters are passed in via yaml
@@ -99,8 +97,8 @@ async def async_request_eb_openai_chat_completions(
 
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
 
-        print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
+        print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
 
         headers = {
             "Content-Type": "application/json",
@@ -115,16 +113,14 @@ async def async_request_eb_openai_chat_completions(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue
 
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                        if chunk != "[DONE]":
                             # print("####chunk:", chunk, type(chunk))
                             timestamp = time.perf_counter()
@@ -138,22 +134,20 @@ async def async_request_eb_openai_chat_completions(
                                     ttft = timestamp - st
                                     output.ttft = ttft
                                     # cached_tokens
-                                    output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
+                                    output.prompt_len = (
+                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
+                                    )
 
                                 # Decoding phase
                                 else:
-                                    output.itl.append(timestamp -
-                                                      most_recent_timestamp)
+                                    output.itl.append(timestamp - most_recent_timestamp)
 
                                 output.generated_text += content or ""
                                 output.reasoning_content += reason_content or ""
                                 output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                             elif usage := data.get("usage", {}):
-                                output.output_tokens = usage.get(
-                                    "completion_tokens", 0)
-                                output.prompt_tokens = usage.get(
-                                    "prompt_tokens", 0)
+                                output.output_tokens = usage.get("completion_tokens", 0)
+                                output.prompt_tokens = usage.get("prompt_tokens", 0)
 
                             most_recent_timestamp = timestamp
 
@@ -166,7 +160,12 @@ async def async_request_eb_openai_chat_completions(
                     output.latency = most_recent_timestamp - st
                 else:
                     error_text = await response.text()
-                    print("####error response:", error_text, "####payload:", payload)
+                    print(
+                        "####error response:",
+                        error_text,
+                        "####payload:",
+                        payload,
+                    )
                     output.error = error_text or ""
                     output.success = False
         except Exception:
@@ -194,15 +193,14 @@ async def async_request_eb_openai_completions(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
             "model": request_func_input.model,
             "prompt": request_func_input.prompt,
             "stream": True,
             "stream_options": {
                 "include_usage": True,
-                "continuous_usage_stats": True
+                "continuous_usage_stats": True,
             },
         }
         # hyperparameters are passed in via yaml
@@ -210,12 +208,12 @@ async def async_request_eb_openai_completions(
 
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
 
         print("payload:", json.dumps(payload, ensure_ascii=False))
 
         headers = {
             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
-            "Content-Type": "application/json"
+            "Content-Type": "application/json",
         }
 
         output = RequestFuncOutput()
@@ -227,8 +225,7 @@ async def async_request_eb_openai_completions(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     first_chunk_received = False
                     async for chunk_bytes in response.content:
@@ -236,8 +233,7 @@ async def async_request_eb_openai_completions(
                         if not chunk_bytes:
                             continue
 
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                         if chunk != "[DONE]":
                             # print("####chunk:", chunk, chunk.usage)
                             timestamp = time.perf_counter()
@@ -250,7 +246,7 @@ async def async_request_eb_openai_completions(
                             # Note that text could be empty here
                             # e.g. for special tokens
                             text = choices[0].get("text")
 
                             # First token
                             if not first_chunk_received:
                                 first_chunk_received = True
@@ -259,26 +255,23 @@ async def async_request_eb_openai_completions(
 
                             # Decoding phase
                             else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
+                                output.itl.append(timestamp - most_recent_timestamp)
 
                             generated_text += text or ""
 
                             most_recent_timestamp = timestamp
                             output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                         elif usage := data.get("usage"):
-                            output.prompt_tokens = usage.get(
-                                "prompt_tokens")
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
+                            output.prompt_tokens = usage.get("prompt_tokens")
+                            output.output_tokens = usage.get("completion_tokens")
                     if first_chunk_received:
                         output.success = True
                     else:
                         output.success = False
                         output.error = (
-                            "Never received a valid chunk to calculate TTFT."
-                            "This response will be marked as failed!")
+                            "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
+                        )
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
@@ -294,8 +287,8 @@ async def async_request_eb_openai_completions(
             output.success = False
             exc_info = sys.exc_info()
             output.error = "".join(traceback.format_exception(*exc_info))
 
-        print("final_output:{}".format(output))
+        print(f"final_output:{output}")
 
         if pbar:
             pbar.update(1)
@@ -310,8 +303,7 @@ async def async_request_tgi(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         params = {
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
@@ -358,8 +350,7 @@ async def async_request_tgi(
 
                             # Decoding phase
                             else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
+                                output.itl.append(timestamp - most_recent_timestamp)
 
                             most_recent_timestamp = timestamp
                             output.arrival_time.append(data["arrival_time"])
@@ -388,8 +379,7 @@ async def async_request_trt_llm(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -414,8 +404,7 @@ async def async_request_trt_llm(
                         if not chunk_bytes:
                             continue
 
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data:")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
 
                         data = json.loads(chunk)
                         output.generated_text += data["text_output"]
@@ -427,8 +416,7 @@ async def async_request_trt_llm(
 
                         # Decoding phase
                         else:
-                            output.itl.append(timestamp -
-                                              most_recent_timestamp)
+                            output.itl.append(timestamp - most_recent_timestamp)
 
                         most_recent_timestamp = timestamp
 
@@ -453,8 +441,7 @@ async def async_request_deepspeed_mii(
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     """Request an LLM using Deepspeed MII"""
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
 
         payload = {
             "prompt": request_func_input.prompt,
@@ -472,19 +459,16 @@ async def async_request_deepspeed_mii(
 
         st = time.perf_counter()
         try:
-            async with session.post(url=request_func_input.api_url,
-                                    json=payload) as response:
+            async with session.post(url=request_func_input.api_url, json=payload) as response:
                 if response.status == 200:
                     parsed_resp = await response.json()
                     output.latency = time.perf_counter() - st
                     if "choices" in parsed_resp:
-                        output.generated_text = parsed_resp["choices"][0][
-                            "text"]
+                        output.generated_text = parsed_resp["choices"][0]["text"]
                     elif "text" in parsed_resp:
                         output.generated_text = parsed_resp["text"][0]
                     else:
-                        output.error = ("Unexpected response format: "
-                                        "neither 'choices' nor 'text' found")
+                        output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
                         output.success = False
                     output.success = True
                 else:
@@ -510,26 +494,22 @@ async def async_request_openai_completions(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": request_func_input.model_name \
-                if request_func_input.model_name else request_func_input.model,
+            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
             "prompt": request_func_input.prompt,
             # "temperature": 0.0,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,
-            #"stream_options": {
+            # "stream_options": {
             #     "include_usage": True,
-            #},
+            # },
         }
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
 
-        headers = {
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
-        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
 
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
@@ -538,8 +518,7 @@ async def async_request_openai_completions(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     first_chunk_received = False
                     async for chunk_bytes in response.content:
@@ -547,8 +526,7 @@ async def async_request_openai_completions(
                         if not chunk_bytes:
                             continue
 
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                         if chunk != "[DONE]":
                             # print("####chunk:", chunk, type(chunk))
                             data = json.loads(chunk)
@@ -569,21 +547,19 @@ async def async_request_openai_completions(
 
                             # Decoding phase
                             else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
+                                output.itl.append(timestamp - most_recent_timestamp)
 
                             most_recent_timestamp = timestamp
                             generated_text += text or ""
                         elif usage := data.get("usage"):
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
+                            output.output_tokens = usage.get("completion_tokens")
                         if first_chunk_received:
                             output.success = True
                         else:
                             output.success = False
                             output.error = (
-                                "Never received a valid chunk to calculate TTFT."
-                                "This response will be marked as failed!")
+                                "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
+                            )
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
                 else:
@@ -606,25 +582,24 @@ async def async_request_openai_audio(
     """Request an LLM using OpenAI"""
     # Lazy import without PlaceholderModule to avoid vllm dep.
     import soundfile
 
     api_url = request_func_input.api_url
     assert api_url.endswith(
-        ("transcriptions", "translations"
-         )), "OpenAI Chat Completions API URL must end with 'transcriptions' "
+        ("transcriptions", "translations")
+    ), "OpenAI Chat Completions API URL must end with 'transcriptions' "
     "or `translations`."
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         content = [{"type": "text", "text": request_func_input.prompt}]
         payload = {
-            "model": request_func_input.model_name \
-                if request_func_input.model_name else request_func_input.model,
+            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
             "temperature": 0.0,
             "max_completion_tokens": request_func_input.output_len,
             "stream": True,
             "language": "en",
             # Flattened due to multipart/form-data
             "stream_include_usage": True,
-            "stream_continuous_usage_stats": True
+            "stream_continuous_usage_stats": True,
         }
         if request_func_input.extra_body:
             payload.update(request_func_input.extra_body)
@@ -639,9 +614,9 @@ async def async_request_openai_audio(
             buffer.seek(0)
             return buffer
 
-        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
+        with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
             form = aiohttp.FormData()
-            form.add_field('file', f, content_type='audio/wav')
+            form.add_field("file", f, content_type="audio/wav")
             for key, value in payload.items():
                 form.add_field(key, str(value))
 
@@ -653,24 +628,20 @@ async def async_request_openai_audio(
             st = time.perf_counter()
             most_recent_timestamp = st
             try:
-                async with session.post(url=api_url,
-                                        data=form,
-                                        headers=headers) as response:
+                async with session.post(url=api_url, data=form, headers=headers) as response:
                     if response.status == 200:
                         async for chunk_bytes in response.content:
                             chunk_bytes = chunk_bytes.strip()
                             if not chunk_bytes:
                                 continue
 
-                            chunk = chunk_bytes.decode("utf-8").removeprefix(
-                                "data: ")
+                            chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                             if chunk != "[DONE]":
                                 timestamp = time.perf_counter()
                                 data = json.loads(chunk)
 
                                 if choices := data.get("choices"):
-                                    content = choices[0]["delta"].get(
-                                        "content")
+                                    content = choices[0]["delta"].get("content")
                                     # First token
                                     if ttft == 0.0:
                                         ttft = timestamp - st
@@ -678,13 +649,11 @@ async def async_request_openai_audio(
 
                                     # Decoding phase
                                     else:
-                                        output.itl.append(
-                                            timestamp - most_recent_timestamp)
+                                        output.itl.append(timestamp - most_recent_timestamp)
 
                                     generated_text += content or ""
                                 elif usage := data.get("usage"):
-                                    output.output_tokens = usage.get(
-                                        "completion_tokens")
+                                    output.output_tokens = usage.get("completion_tokens")
 
                                 most_recent_timestamp = timestamp
 
@@ -718,8 +687,11 @@ ASYNC_REQUEST_FUNCS = {
 }
 
 OPENAI_COMPATIBLE_BACKENDS = [
-    k for k, v in ASYNC_REQUEST_FUNCS.items()
-    if v in (async_request_openai_completions,
-             async_request_eb_openai_chat_completions)
+    k
+    for k, v in ASYNC_REQUEST_FUNCS.items()
+    if v
+    in (
+        async_request_openai_completions,
+        async_request_eb_openai_chat_completions,
+    )
 ]
benchmarks/benchmark_dataset.py
@@ -26,9 +26,9 @@ from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass
 from io import BytesIO
-from typing import Any, Callable, Optional, Union
-from PIL import Image
+from typing import Any, Optional, Union
 
+from PIL import Image
 
 logger = logging.getLogger(__name__)
 
@@ -38,6 +38,7 @@ class SampleRequest:
     """
     Represents a single inference request for benchmarking.
     """
+
     no: int
     prompt: Union[str, Any]
     history_QA: Union[str, Any]
@@ -48,6 +49,7 @@ class SampleRequest:
 
 class BenchmarkDataset(ABC):
     """BenchmarkDataset"""
+
     DEFAULT_SEED = 0
     IS_MULTIMODAL = False
 
@@ -68,8 +70,7 @@ class BenchmarkDataset(ABC):
         self.dataset_path = dataset_path
         # Set the random seed, ensuring that a None value is replaced with the
         # default seed.
-        self.random_seed = (random_seed
-                            if random_seed is not None else self.DEFAULT_SEED)
+        self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
         self.data = None
         self.hyperparameter_path = hyperparameter_path
         self.hyperparameters = {}
@@ -85,8 +86,7 @@ class BenchmarkDataset(ABC):
             NotImplementedError: If a subclass does not implement this method.
         """
         # TODO (jenniferzhao): add support for downloading data
-        raise NotImplementedError(
-            "load_data must be implemented in subclasses.")
+        raise NotImplementedError("load_data must be implemented in subclasses.")
 
     @abstractmethod
     def sample(self, num_requests: int) -> list[SampleRequest]:
@@ -105,8 +105,7 @@ class BenchmarkDataset(ABC):
         """
         raise NotImplementedError("sample must be implemented in subclasses.")
 
-    def maybe_oversample_requests(self, requests: list[SampleRequest],
-                                  num_requests: int) -> None:
+    def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
         """
         Oversamples the list of requests if its size is less than the desired
         number.
@@ -117,11 +116,9 @@ class BenchmarkDataset(ABC):
         """
         if len(requests) < num_requests:
             random.seed(self.random_seed)
-            additional = random.choices(requests,
-                                        k=num_requests - len(requests))
+            additional = random.choices(requests, k=num_requests - len(requests))
             requests.extend(additional)
-            logger.info("Oversampled requests to reach %d total samples.",
-                        num_requests)
+            logger.info("Oversampled requests to reach %d total samples.", num_requests)
 
 
 def is_valid_sequence(
@@ -141,14 +138,12 @@ def is_valid_sequence(
     """
     # Check for invalid conditions
     prompt_too_short = prompt_len < min_len
-    output_too_short = (not skip_min_output_len_check) and (output_len
-                                                            < min_len)
+    output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
     prompt_too_long = prompt_len > max_prompt_len
     combined_too_long = (prompt_len + output_len) > max_total_len
 
     # Return True if none of the invalid conditions are met
-    return not (prompt_too_short or output_too_short or prompt_too_long
-                or combined_too_long)
+    return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)
 
 
 def process_image(image: Any) -> Mapping[str, Any]:
@@ -171,28 +166,25 @@ def process_image(image: Any) -> Mapping[str, Any]:
     Raises:
         ValueError: If the input is not a supported type.
     """
-    if isinstance(image, dict) and 'bytes' in image:
-        image = Image.open(BytesIO(image['bytes']))
+    if isinstance(image, dict) and "bytes" in image:
+        image = Image.open(BytesIO(image["bytes"]))
     if isinstance(image, Image.Image):
         image = image.convert("RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
-            image_base64 = base64.b64encode(
-                image_data.getvalue()).decode("utf-8")
+            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
         return {
             "type": "image_url",
-            "image_url": {
-                "url": f"data:image/jpeg;base64,{image_base64}"
-            },
+            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
         }
 
     if isinstance(image, str):
-        image_url = (image if image.startswith(
-            ("http://", "file://")) else f"file://{image}")
+        image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
        return {"type": "image_url", "image_url": {"url": image_url}}
 
-    raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
-                     " or str or dictionary with raw image bytes.")
+    raise ValueError(
+        f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
+    )
 
 
 class EBDataset(BenchmarkDataset):
@@ -243,8 +235,7 @@ class EBDataset(BenchmarkDataset):
                 new_output_len = int(entry["max_dec_len"])
 
                 if enable_multimodal_chat:
-                    prompt = self.apply_multimodal_chat_transformation(
-                        prompt, None)
+                    prompt = self.apply_multimodal_chat_transformation(prompt, None)
                 samples.append(
                     SampleRequest(
                         no=cnt,
@@ -252,17 +243,20 @@ class EBDataset(BenchmarkDataset):
                         prompt_len=self.prompt_len,
                         history_QA=[],
                         expected_output_len=new_output_len,
-                    ))
+                    )
+                )
                 cnt += 1
 
         self.maybe_oversample_requests(samples, num_requests)
         return samples
 
 
 class EBChatDataset(BenchmarkDataset):
     """
     Implements the ShareGPT dataset. Loads data from a JSON file and generates
     sample requests based on conversation turns.
     """
 
     prompt_len: int
 
     def __init__(self, **kwargs) -> None:
@@ -296,8 +290,7 @@ class EBChatDataset(BenchmarkDataset):
                 new_output_len = int(entry.get("max_tokens", 12288))
 
                 if enable_multimodal_chat:
-                    prompt = self.apply_multimodal_chat_transformation(
-                        prompt, None)
+                    prompt = self.apply_multimodal_chat_transformation(prompt, None)
                 samples.append(
                     SampleRequest(
                         no=cnt,
@@ -306,9 +299,9 @@ class EBChatDataset(BenchmarkDataset):
                         prompt_len=0,
                         history_QA=history_QA,
                         expected_output_len=new_output_len,
-                    ))
+                    )
+                )
                 cnt += 1
 
         self.maybe_oversample_requests(samples, num_requests)
         return samples
benchmarks/benchmark_mtp.py
@@ -18,28 +18,16 @@ import argparse
 import asyncio
 import contextlib
 import os
-import signal
-import socket
-import subprocess
-import time
 from typing import Union
 
-import openai
-import yaml
-from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
+from benchmark_dataset import EBChatDataset, EBDataset
 from benchmark_serving import benchmark
 
 
-def prepare_input_requests(
-    num_prompts: int, dataset_name: str, dataset_path: str
-) -> Union[EBDataset, EBChatDataset]:
+def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
     dataset_mapping = {
-        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
-            num_requests=num_prompts
-        ),
-        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
-            num_requests=num_prompts
-        ),
+        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
+        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
     }
 
     try:
@@ -104,24 +92,27 @@ def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
 def main(args):
     base_url = f"http://{args.host}:{args.port}"
 
-    input_requests = prepare_input_requests(
-        args.num_prompts, args.dataset_name, args.dataset_path
-    )
+    input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)
 
     if len(args.max_concurrency) != len(args.s_itl_base_model):
-        raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
+        raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
 
     for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
         # Warmup
         print("Starting warmup...")
         with open(os.devnull, "w") as f:
             with contextlib.redirect_stdout(f):
-                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
+                send_one_batch(
+                    base_url,
+                    max_concurrency,
+                    input_requests[0:max_concurrency],
+                    True,
+                )
 
         # Benchmark
         record = send_one_batch(base_url, max_concurrency, input_requests, False)
 
-        metric_header = f"Speed up"
+        metric_header = "Speed up"
         print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
         for draft_token_step in args.draft_token_steps:
             speedup = calculate_speedup(
@@ -130,11 +121,7 @@ def main(args):
                 s_itl,
                 record["mean_s_itl_ms"],
             )
-            print(
-                "{:<40} {:<10.2f}".format(
-                    f"Speed up on {draft_token_step} steps draft", speedup
-                )
-            )
+            print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
         print("=" * 50)
|
File diff suppressed because it is too large
Load Diff
benchmarks/benchmark_utils.py
@@ -24,9 +24,11 @@ import os
 from typing import Any
 
 
-def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
-                                        metrics: dict[str, list],
-                                        extra_info: dict[str, Any]) -> list:
+def convert_to_pytorch_benchmark_format(
+    args: argparse.Namespace,
+    metrics: dict[str, list],
+    extra_info: dict[str, Any],
+) -> list:
     """
     Save the benchmark results in the format used by PyTorch OSS benchmark with
     one metric per record
@@ -54,12 +56,10 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
         },
     }
 
-    tp = record["benchmark"]["extra_info"]["args"].get(
-        "tensor_parallel_size")
+    tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
     # Save tensor_parallel_size parameter if it's part of the metadata
     if not tp and "tensor_parallel_size" in extra_info:
-        record["benchmark"]["extra_info"]["args"][
-            "tensor_parallel_size"] = extra_info["tensor_parallel_size"]
+        record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]
 
     records.append(record)
 
@@ -68,6 +68,7 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
 
 class InfEncoder(json.JSONEncoder):
     """InfEncoder"""
+
     def clear_inf(self, o: Any):
         """clear_inf"""
         if isinstance(o, dict):
@@ -87,4 +88,3 @@ def write_to_json(filename: str, records: list) -> None:
     """write_to_json"""
     with open(filename, "w") as f:
         json.dump(records, f, cls=InfEncoder)
-
@@ -25,32 +25,32 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
import yaml
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
import requests
|
|
||||||
import copy
|
|
||||||
from collections.abc import AsyncGenerator, Iterable
|
from collections.abc import AsyncGenerator, Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from backend_request_func import (ASYNC_REQUEST_FUNCS,
|
import requests
|
||||||
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
|
import yaml
|
||||||
RequestFuncOutput)
|
from backend_request_func import (
|
||||||
|
ASYNC_REQUEST_FUNCS,
|
||||||
|
OPENAI_COMPATIBLE_BACKENDS,
|
||||||
|
RequestFuncInput,
|
||||||
|
RequestFuncOutput,
|
||||||
|
)
|
||||||
|
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
|
||||||
|
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
|
|
||||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
|
||||||
|
|
||||||
from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
|
|
||||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
|
||||||
|
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkMetrics:
|
class BenchmarkMetrics:
|
||||||
"""Class containing all metrics that are used in this script"""
|
"""Class containing all metrics that are used in this script"""
|
||||||
|
|
||||||
completed: int
|
completed: int
|
||||||
total_input: int
|
total_input: int
|
||||||
total_output: int
|
total_output: int
|
||||||
@@ -133,8 +133,7 @@ async def get_request(
|
|||||||
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
||||||
|
|
||||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||||
assert burstiness > 0, (
|
assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
|
||||||
f"A positive burstiness factor is expected, but given {burstiness}.")
|
|
||||||
theta = 1.0 / (request_rate * burstiness)
|
theta = 1.0 / (request_rate * burstiness)
|
||||||
|
|
||||||
for request in input_requests:
|
for request in input_requests:
|
||||||
@@ -160,7 +159,7 @@ def calculate_metrics(
|
|||||||
) -> tuple[BenchmarkMetrics, list[int]]:
|
) -> tuple[BenchmarkMetrics, list[int]]:
|
||||||
"""Calculates various performance metrics based on the inputs and outputs."""
|
"""Calculates various performance metrics based on the inputs and outputs."""
|
||||||
input_lens: list[int] = []
|
input_lens: list[int] = []
|
||||||
infer_input_lens: list[int] = [] # 推理侧输入token数
|
infer_input_lens: list[int] = [] # 推理侧输入token数
|
||||||
actual_output_lens: list[int] = []
|
actual_output_lens: list[int] = []
|
||||||
total_input = 0
|
total_input = 0
|
||||||
completed = 0
|
completed = 0
|
||||||
@@ -210,8 +209,9 @@ def calculate_metrics(
|
|||||||
s_e2els.append(outputs[i].arrival_time[-1])
|
s_e2els.append(outputs[i].arrival_time[-1])
|
||||||
# 解码速度去掉首token
|
# 解码速度去掉首token
|
||||||
if len(outputs[i].arrival_time) > 2:
|
if len(outputs[i].arrival_time) > 2:
|
||||||
s_decodes.append((outputs[i].output_tokens - 1) /
|
s_decodes.append(
|
||||||
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
|
(outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
|
||||||
|
)
|
||||||
completed += 1
|
completed += 1
|
||||||
else:
|
else:
|
||||||
actual_output_lens.append(0)
|
actual_output_lens.append(0)
|
||||||
@@ -224,16 +224,13 @@ def calculate_metrics(
|
|||||||
|
|
||||||
if "ttft" in goodput_config_dict:
|
if "ttft" in goodput_config_dict:
|
||||||
valid_metrics.append(ttfts)
|
valid_metrics.append(ttfts)
|
||||||
slo_values.append(goodput_config_dict["ttft"] /
|
slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
|
||||||
if "tpot" in goodput_config_dict:
|
if "tpot" in goodput_config_dict:
|
||||||
valid_metrics.append(all_tpots)
|
valid_metrics.append(all_tpots)
|
||||||
slo_values.append(goodput_config_dict["tpot"] /
|
slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
|
||||||
if "e2el" in goodput_config_dict:
|
if "e2el" in goodput_config_dict:
|
||||||
valid_metrics.append(e2els)
|
valid_metrics.append(e2els)
|
||||||
slo_values.append(goodput_config_dict["e2el"] /
|
slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
|
||||||
|
|
||||||
for req_metric in zip(*valid_metrics):
|
for req_metric in zip(*valid_metrics):
|
||||||
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
||||||
@@ -242,9 +239,9 @@ def calculate_metrics(
|
|||||||
|
|
||||||
if completed == 0:
|
if completed == 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"All requests failed. This is likely due to a misconfiguration "
|
"All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
|
||||||
"on the benchmark arguments.",
|
stacklevel=2,
|
||||||
stacklevel=2)
|
)
|
||||||
metrics = BenchmarkMetrics(
|
metrics = BenchmarkMetrics(
|
||||||
completed=completed,
|
completed=completed,
|
||||||
total_input=total_input,
|
total_input=total_input,
|
||||||
@@ -253,64 +250,50 @@ def calculate_metrics(
        request_goodput=good_completed / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
        mean_s_decode=np.mean(s_decodes or 0) * 1,  # ttfts is empty if streaming is not supported by backend
        std_s_decode=np.std(s_decodes or 0) * 1,
        median_s_decode=np.median(s_decodes or 0) * 1,
        percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
        mean_ttft_ms=np.mean(ttfts or 0) * 1000,  # ttfts is empty if streaming is not supported by backend
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
        mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000,  # ttfts is empty if streaming is not supported by backend
        std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
        median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
        percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
        mean_itl_ms=np.mean(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
        mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
        std_s_itl_ms=np.std(s_itls or 0) * 1000,
        median_s_itl_ms=np.median(s_itls or 0) * 1000,
        percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
        mean_e2el_ms=np.mean(e2els or 0) * 1000,
        std_e2el_ms=np.std(e2els or 0) * 1000,
        median_e2el_ms=np.median(e2els or 0) * 1000,
        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
        mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
        std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
        median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
        percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
        mean_input_len=np.mean(input_lens or 0) * 1,
        std_input_len=np.std(input_lens or 0) * 1,
        median_input_len=np.median(input_lens or 0) * 1,
        percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
        mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
        std_s_input_len=np.std(infer_input_lens or 0) * 1,
        median_s_input_len=np.median(infer_input_lens or 0) * 1,
        percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
        mean_output_len=np.mean(actual_output_lens or 0) * 1,
        std_output_len=np.std(actual_output_lens or 0) * 1,
        median_output_len=np.median(actual_output_lens or 0) * 1,
        percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
    )

    return metrics, actual_output_lens
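A detail worth noting in the metrics block above: expressions like np.mean(ttfts or 0) rely on `or 0` so that an empty list (e.g. when the backend does not stream) degrades to 0.0 instead of raising a mean-of-empty-slice warning. A quick sketch:

import numpy as np

ttfts = []  # empty when streaming is not supported by the backend
print(np.mean(ttfts or 0) * 1000)   # -> 0.0, no RuntimeWarning
ttfts = [0.12, 0.34]
print(np.mean(ttfts or 0) * 1000)   # -> 230.0 (milliseconds)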
@@ -351,20 +334,22 @@ async def benchmark(

    if lora_modules:
        # For each input request, choose a LoRA module at random.
        lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])

    if profile:
        print("Starting profiler...")
        test_prompt = None
        test_output_len = None
        profile_input = RequestFuncInput(
            model=model_id,
            model_name=model_name,
            prompt=test_prompt,
            api_url=base_url + "/start_profile",
            output_len=test_output_len,
            logprobs=logprobs,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler started")
@@ -384,19 +369,16 @@ async def benchmark(
    # and it will simplify the code in limited_request_func.
    # semaphore = (asyncio.Semaphore(max_concurrency)
    #              if max_concurrency else contextlib.nullcontext())
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    benchmark_start_time = time.perf_counter()

    print(f"Start time: {datetime.now()}")
    tasks: list[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate, burstiness):
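The optional-semaphore pattern above (semaphore is None when no concurrency cap is given) is easy to exercise on its own. A hedged sketch with a stubbed request function, not the script's real request_func:

import asyncio

async def fake_request(i):  # stand-in for the real request_func
    await asyncio.sleep(0.01)
    return i

async def run(max_concurrency=None):
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited(i):
        if semaphore is None:          # no cap: fire straight away
            return await fake_request(i)
        async with semaphore:          # cap: at most max_concurrency in flight
            return await fake_request(i)

    return await asyncio.gather(*(limited(i) for i in range(8)))

print(asyncio.run(run(max_concurrency=2)))  # -> [0, 1, 2, 3, 4, 5, 6, 7]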
@@ -409,25 +391,26 @@ async def benchmark(
            req_lora_module = next(lora_modules)
            req_model_id, req_model_name = req_lora_module, req_lora_module

        request_func_input = RequestFuncInput(
            model=req_model_id,
            model_name=req_model_name,
            prompt=prompt,
            prompt_len=0,
            history_QA=history_QA,
            hyper_parameters=hyper_parameters,
            api_url=api_url,
            output_len=output_len,
            logprobs=logprobs,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )
        tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))

    outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
    print(f"Finish time: {datetime.now()}")
    if profile:
        print("Stopping profiler...")
        test_output_len = None
        profile_input = RequestFuncInput(
            model=model_id,
            prompt=test_prompt,
@@ -454,22 +437,16 @@ async def benchmark(
    )
    print("Benchmark complete!!!")

    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
    if goodput_config_dict:
        print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))

    result = {
        "duration": benchmark_duration,
@@ -477,8 +454,7 @@ async def benchmark(
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
        "output_throughput": metrics.output_throughput,
        "total_token_throughput": metrics.total_token_throughput,
        "input_lens": [output.prompt_len for output in outputs],
@@ -491,7 +467,6 @@ async def benchmark(
        "reasoning_contents": [output.reasoning_content for output in outputs],
        "errors": [output.error for output in outputs],
    }

    def process_one_metric(
        # E.g., "ttft"
@@ -505,24 +480,25 @@ async def benchmark(
        # metric.
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_attribute_name}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_attribute_name}_ms"),
            )
        )
        result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
        result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
        result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
            result[f"p{p_word}_{metric_attribute_name}_ms"] = value

    def process_one_length(
@@ -537,31 +513,31 @@ async def benchmark(
        # metric.
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name}:",
                getattr(metrics, f"mean_{metric_attribute_name}"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name}:",
                getattr(metrics, f"median_{metric_attribute_name}"),
            )
        )
        result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
        result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
        result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
            result[f"p{p_word}_{metric_attribute_name}"] = value

    process_one_length("s_decode", "Decode", "Decode speed (tok/s)")
    process_one_metric("ttft", "TTFT", "Time to First Token")
    process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
    process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
    process_one_metric("itl", "ITL", "Inter-token Latency")
    process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
    process_one_metric("e2el", "E2EL", "End-to-end Latency")
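process_one_metric/process_one_length above reduce to formatting mean, median, std, and the selected percentiles of a list. The same idea standalone, with invented sample values:

import numpy as np

latencies_ms = [12.0, 15.5, 11.2, 40.3, 13.8]  # hypothetical per-request values
selected_percentiles = [50.0, 99.0]

print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", np.mean(latencies_ms)))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", np.median(latencies_ms)))
for p in selected_percentiles:
    p_word = str(int(p)) if int(p) == p else str(p)  # "99" rather than "99.0"
    print("{:<40} {:<10.2f}".format(f"P{p_word} TTFT (ms):", np.percentile(latencies_ms, p)))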
@@ -581,6 +557,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
    """
    Quick evaluation
    """

    def process_quick_metric(
        metric_attribute_name: str,
        metric_name: str,
@@ -588,7 +565,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
    ):
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms")
        print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value))
        quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value
@@ -600,17 +577,17 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
    ):
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        mean_value = getattr(metrics, f"mean_{metric_attribute_name}")
        print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value))
        quick_result[f"mean_{metric_attribute_name}"] = mean_value

    print("\n\n\n")
    print("{s:{c}^{n}}".format(s=" Benchmark Quick Summary ", n=50, c="="))
    process_quick_length("s_decode", "Decode", "Decode speed (tok/s)")
    process_quick_metric("ttft", "TTFT", "Time to First Token")
    process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
    process_quick_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
    process_quick_metric("itl", "ITL", "Inter-token Latency")
    process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
    process_quick_metric("e2el", "E2EL", "End-to-end Latency")
@@ -633,12 +610,14 @@ def check_goodput_args(args):
                raise ValueError(
                    f"Invalid metric name found, {slo_name}: {slo_val}. "
                    "The service level objective name should be one of "
                    f"{VALID_NAMES!s}. "
                )
            if slo_val < 0:
                raise ValueError(
                    f"Invalid value found, {slo_name}: {slo_val}. "
                    "The service level objective value should be "
                    "non-negative."
                )
    return goodput_config_dict
@@ -652,37 +631,43 @@ def parse_goodput(slo_pairs):
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            "Invalid format found for service level objectives. "
            'Specify service level objectives for goodput as "KEY:VALUE" '
            "pairs, where the key is a metric name, and the value is a "
            "number in milliseconds."
        ) from err
    return goodput_config_dict

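For reference, parse_goodput consumes space-separated "KEY:VALUE" pairs. A standalone sketch of the same parsing, assuming the semantics shown in the error message above:

def parse_goodput_pairs(slo_pairs):
    # "ttft:500 e2el:2000" -> {"ttft": 500.0, "e2el": 2000.0}
    config = {}
    for slo_pair in slo_pairs:
        slo_name, slo_val = slo_pair.split(":")
        config[slo_name] = float(slo_val)
    return config

print(parse_goodput_pairs(["ttft:500", "e2el:2000"]))  # -> {'ttft': 500.0, 'e2el': 2000.0}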
def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
    """Save the benchmarking results to PyTorch Benchmark Format JSON file"""
    metrics = [
        "median_ttft_ms",
        "mean_ttft_ms",
        "std_ttft_ms",
        "p99_ttft_ms",
        "mean_tpot_ms",
        "median_tpot_ms",
        "std_tpot_ms",
        "p99_tpot_ms",
        "median_itl_ms",
        "mean_itl_ms",
        "std_itl_ms",
        "p99_itl_ms",
    ]
    # These raw data might be useful, but they are rather big. They can be added
    # later if needed
    ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={k: [results[k]] for k in metrics},
        extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
    )
    if pt_records:
        # Don't use json suffix here as we don't want CI to pick it up
        pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)


def check_health(api_base_url: str) -> bool:
    health_url = api_base_url.rstrip("/") + "/health"
    try:
@@ -697,6 +682,7 @@ def check_health(api_base_url: str) -> bool:
        print(f"[HEALTH] Failed to connect to {health_url}: {e}")
        return False
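Only the failure branch of check_health is visible in this hunk. A minimal sketch of the same probe, assuming a plain GET against /health with the third-party requests library (the real implementation may differ):

import requests

def probe_health(api_base_url: str, timeout: float = 5.0) -> bool:
    health_url = api_base_url.rstrip("/") + "/health"
    try:
        resp = requests.get(health_url, timeout=timeout)
        return resp.status_code == 200
    except requests.RequestException as e:
        print(f"[HEALTH] Failed to connect to {health_url}: {e}")
        return False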

def main(args: argparse.Namespace):
    """Main entry point"""
    print(args)
@@ -707,7 +693,6 @@ def main(args: argparse.Namespace):
    model_id = args.model
    model_name = args.served_model_name
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
@@ -717,23 +702,17 @@ def main(args: argparse.Namespace):
        base_url = f"http://{args.host}:{args.port}"

    if args.dataset_name is None:
        raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")

    # For datasets that follow a similar structure, use a mapping.
    dataset_mapping = {
        "EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
            num_requests=args.num_prompts,
            output_len=args.sharegpt_output_len,
        ),
        "EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
            num_requests=args.num_prompts,
            output_len=args.sharegpt_output_len,
        ),
    }
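dataset_mapping above wraps each dataset constructor in a lambda so that only the selected dataset is actually instantiated. A reduced sketch of the dispatch, with stub classes standing in for EBDataset/EBChatDataset:

class StubDataset:  # stand-in for EBDataset / EBChatDataset
    def __init__(self, name):
        self.name = name

    def sample(self, num_requests):
        return [f"{self.name}-request-{i}" for i in range(num_requests)]

dataset_mapping = {
    "EB": lambda: StubDataset("EB").sample(num_requests=2),
    "EBChat": lambda: StubDataset("EBChat").sample(num_requests=2),
}

dataset_name = "EBChat"
input_requests = dataset_mapping[dataset_name]()  # only this dataset is built
print(input_requests)  # -> ['EBChat-request-0', 'EBChat-request-1']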
@@ -751,15 +730,14 @@ def main(args: argparse.Namespace):
            "top_p": args.top_p,
            "top_k": args.top_k,
            "min_p": args.min_p,
            "temperature": args.temperature,
        }.items()
        if v is not None
    }

    # Sampling parameters are only supported by openai-compatible backend.
    if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
        raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")

    if "temperature" not in sampling_params:
        sampling_params["temperature"] = 0.0  # Default to greedy decoding.
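The sampling_params construction above is a dict comprehension that silently drops CLI options left at None; note that 0.0 survives the filter. The same pattern in isolation, with hypothetical values:

args_top_p, args_top_k, args_min_p, args_temperature = 0.9, None, None, 0.0  # hypothetical CLI values

sampling_params = {
    k: v
    for k, v in {
        "top_p": args_top_p,
        "top_k": args_top_k,
        "min_p": args_min_p,
        "temperature": args_temperature,
    }.items()
    if v is not None
}
print(sampling_params)  # -> {'top_p': 0.9, 'temperature': 0.0}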
@@ -790,15 +768,14 @@ def main(args: argparse.Namespace):
            disable_tqdm=args.disable_tqdm,
            profile=args.profile,
            selected_percentile_metrics=args.percentile_metrics.split(","),
            selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
            ignore_eos=args.ignore_eos,
            goodput_config_dict=goodput_config_dict,
            max_concurrency=args.max_concurrency,
            lora_modules=args.lora_modules,
            extra_body=sampling_params,
        )
    )

    # Save config and results to json
    if args.save_result:
@@ -819,22 +796,23 @@ def main(args: argparse.Namespace):
                    kvstring = item.split("=")
                    result_json[kvstring[0].strip()] = kvstring[1].strip()
                else:
                    raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")

        if not args.save_detailed:
            # Remove fields with too many data points
            for field in [
                "input_lens",
                "output_lens",
                "ttfts",
                "itls",
                "generated_texts",
                "errors",
            ]:
                if field in result_json:
                    del result_json[field]

        # Traffic
        result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
        result_json["burstiness"] = args.burstiness
        result_json["max_concurrency"] = args.max_concurrency
@@ -843,21 +821,19 @@ def main(args: argparse.Namespace):

        # Save to file
        base_model_id = model_id.split("/")[-1]
        max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
        file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
        if args.result_filename:
            file_name = args.result_filename
        if args.result_dir:
            file_name = os.path.join(args.result_dir, file_name)
        with open(file_name, "w", encoding="utf-8") as outfile:
            json.dump(result_json, outfile)
        save_to_pytorch_benchmark_format(args, result_json, file_name)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
@@ -883,18 +859,29 @@ if __name__ == "__main__":
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=[
            "sharegpt",
            "burstgpt",
            "sonnet",
            "random",
            "hf",
            "EB",
            "EBChat",
        ],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
    )
    parser.add_argument(
        "--hyperparameter-path",
        type=str,
        default=None,
        help="Path to the hyperparameter. ",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
@@ -906,7 +893,8 @@ if __name__ == "__main__":
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )

    parser.add_argument(
        "--model",
@@ -917,7 +905,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer, if not using the default tokenizer.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
@@ -930,11 +918,13 @@ if __name__ == "__main__":
        "--logprobs",
        type=int,
        default=None,
        help=(
            "Number of logprobs-per-token to compute & return as part of "
            "the request. If unspecified, then either (1) if beam search "
            "is disabled, no logprobs are computed & a single dummy "
            "logprob is returned for each token; or (2) if beam search "
            "is enabled 1 logprob per token is computed"
        ),
    )
    parser.add_argument(
        "--request-rate",
@@ -971,8 +961,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--save-result",
@@ -1013,35 +1002,38 @@ if __name__ == "__main__":
        "--ignore-eos",
        action="store_true",
        help="Set ignore_eos flag when sending the benchmark request."
        "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
    )
    parser.add_argument(
        "--percentile-metrics",
        type=str,
        default="ttft,tpot,itl",
        help="Comma-separated list of selected metrics to report percentiles. "
        "This argument specifies the metrics to report percentiles. "
        'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
        'Default value is "ttft,tpot,itl".',
    )
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles for selected metrics. "
        'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
        'Default value is "99". '
        'Use "--percentile-metrics" to select metrics.',
    )
    parser.add_argument(
        "--goodput",
        nargs="+",
        required=False,
        help='Specify service level objectives for goodput as "KEY:VALUE" '
        "pairs, where the key is a metric name, and the value is in "
        'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
        "separated by spaces. Allowed request level metric names are "
        '"ttft", "tpot", "e2el". For more context on the definition of '
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
    )

    # group for dataset specific arguments
    sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -1069,8 +1061,8 @@ if __name__ == "__main__":
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
    )

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
@@ -1098,29 +1090,24 @@ if __name__ == "__main__":
        "--random-prefix-len",
        type=int,
        default=0,
        help=(
            "Number of fixed prefix tokens before the random context "
            "in a request. "
            "The total input length is the sum of `random-prefix-len` and "
            "a random "
            "context length sampled from [input_len * (1 - range_ratio), "
            "input_len * (1 + range_ratio)]."
        ),
    )

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
    hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
    hf_group.add_argument(
        "--hf-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
    )

    sampling_group = parser.add_argument_group("sampling parameters")
@@ -1128,52 +1115,58 @@ if __name__ == "__main__":
        "--top-p",
        type=float,
        default=None,
        help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
    )
    sampling_group.add_argument(
        "--top-k",
        type=int,
        default=None,
        help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
    )
    sampling_group.add_argument(
        "--min-p",
        type=float,
        default=None,
        help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
    )
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Temperature sampling parameter. Only has effect on "
        "openai-compatible backends. If not specified, default to greedy "
        "decoding (i.e. temperature==0.0).",
    )

    parser.add_argument(
        "--tokenizer-mode",
        type=str,
        default="auto",
        choices=["auto", "slow", "mistral", "custom"],
        help='The tokenizer mode.\n\n* "auto" will use the '
        'fast tokenizer if available.\n* "slow" will '
        "always use the slow tokenizer. \n* "
        '"mistral" will always use the `mistral_common` tokenizer. \n*'
        '"custom" will use --tokenizer to select the preregistered tokenizer.',
    )

    parser.add_argument(
        "--served-model-name",
        type=str,
        default=None,
        help="The model name used in the API. "
        "If not specified, the model name will be the "
        "same as the ``--model`` argument. ",
    )

    parser.add_argument(
        "--lora-modules",
        nargs="+",
        default=None,
        help="A subset of LoRA module names passed in when "
        "launching the server. For each request, the "
        "script chooses a LoRA module at random.",
    )

    args = parser.parse_args()
@@ -7,4 +7,4 @@ tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint4
reasoning_parser: ernie-45-vl

@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3

@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"

@@ -10,4 +10,4 @@ engine_worker_queue_port: 6677
num_gpu_blocks_override: 1024
cache_transfer_protocol: "rdma"
rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
pd_comm_port: "2334"

@@ -10,4 +10,4 @@ splitwise_role: decode
engine_worker_queue_port: 6678
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334"

@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"

@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3

@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"

@@ -3,4 +3,4 @@ max_num_seqs: 75
gpu_memory_utilization: 0.85
kv_cache_ratio: 0.75
quantization: wint4
tensor_parallel_size: 4

@@ -3,4 +3,4 @@ max_num_seqs: 25
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.75
quantization: wint8
tensor_parallel_size: 4

@@ -1,3 +1,3 @@
metadata:
  min_tokens: 32
  max_tokens: 33

@@ -5,4 +5,4 @@ metadata:
max_tokens: 12288
repetition_penalty: 1.05
frequency_penalty: 0
presence_penalty: 0

@@ -5,4 +5,4 @@ metadata:
max_tokens: 12288
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 1.5

@@ -8,4 +8,4 @@ frequency_penalty: 0
presence_penalty: 0
skip_special_tokens: false
chat_template_kwargs:
  enable_thinking: true

@@ -3,4 +3,4 @@ max_num_seqs: 64
gpu_memory_utilization: 0.9
tensor_parallel_size: 8
quantization: wint8
reasoning_parser: ernie-x1
@@ -40,4 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);

@@ -216,7 +216,7 @@ __global__ void append_dequant_cache_kv_c8(

  uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
      wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

  uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
                        tid % 8 * num_elems_per_128b<CacheT>();

@@ -330,7 +330,7 @@ __global__ void append_dequant_cache_kv_c8(
      v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
      v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;

      convert_c8<T, IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]);  // 4 uint8/fp8 values -> 4 T values
#ifdef C8_DEBUG
      if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {

@@ -373,14 +373,14 @@ void AppendDequantCache(
    paddle::Tensor *k_out,
    paddle::Tensor *v_out,
    const cudaStream_t& stream
) {
  using NV_TYPE = typename cascade_attn_type_traits<T>::type;
  if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
    constexpr int NUM_WARPS = 4;
    int block_num = cache_num_blocks_x.data<int>()[0];
    dim3 grids(block_num, 1, kv_num_heads);
    dim3 blocks(32, NUM_WARPS);

    const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;

    auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;

@@ -41,7 +41,7 @@ __global__ void append_clear_cache_int8_block(
  const int wid = tid / 32;
  const int lane_id = tid % 32;
  const int token_id = blockIdx.x;

  const int bid = batch_id_per_token[token_id];

  const int start_token_idx = cu_seqlens_q[bid];
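Across these kernels, batch_id_per_token and cu_seqlens_q implement the usual flattened-batch lookup: each token index maps to its batch id, and the exclusive prefix sums of the sequence lengths give that batch's first flattened token. A Python sketch of the indexing, with illustrative data only:

# Two sequences of lengths 3 and 2, flattened into 5 tokens.
seq_lens = [3, 2]
cu_seqlens_q = [0, 3, 5]              # exclusive prefix sums of seq_lens
batch_id_per_token = [0, 0, 0, 1, 1]  # token index -> batch id

token_id = 4
bid = batch_id_per_token[token_id]    # batch this token belongs to (1)
start_token_idx = cu_seqlens_q[bid]   # first flattened index of that batch (3)
print(bid, start_token_idx, token_id - start_token_idx)  # -> 1 3 1 (position within its sequence)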
@@ -115,7 +115,7 @@ __global__ void append_clear_cache_int4_block(
|
|||||||
const int wid = tid / 32;
|
const int wid = tid / 32;
|
||||||
const int lane_id = tid % 32;
|
const int lane_id = tid % 32;
|
||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
@@ -484,7 +484,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
const int wid = tid / 32;
|
const int wid = tid / 32;
|
||||||
const int lane_id = tid % 32;
|
const int lane_id = tid % 32;
|
||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
@@ -716,7 +716,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
|||||||
const int wid = tid / 32;
|
const int wid = tid / 32;
|
||||||
const int lane_id = tid % 32;
|
const int lane_id = tid % 32;
|
||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
@@ -1097,7 +1097,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
  const int lane_id = tid % 32;

  const int token_id = blockIdx.x;

  const int bid = batch_id_per_token[token_id];

  const int start_token_idx = cu_seqlens_q[bid];

@@ -1403,7 +1403,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
  const int lane_id = tid % 32;

  const int token_id = blockIdx.x;

  const int bid = batch_id_per_token[token_id];

  const int start_token_idx = cu_seqlens_q[bid];
@@ -1792,4 +1792,4 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
          (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
    }
  }
}
@@ -582,4 +582,4 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);

@@ -39,4 +39,4 @@ void SpeculateWriteCacheWithRoPEKernel(
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
@@ -30,4 +30,4 @@ inline int getSMVersion()
  return sm_major * 10 + sm_minor;
}

}
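For reference, `getSMVersion` above folds the device's compute capability into `sm_major * 10 + sm_minor`. A minimal, self-contained sketch of how those two inputs are usually queried — the attribute calls are standard CUDA runtime API; the wrapper name here is illustrative, not from this commit:

#include <cstdio>
#include <cuda_runtime.h>

// Query the current device's compute capability and fold it into one int,
// e.g. 80 for SM80 (Ampere) or 90 for SM90 (Hopper).
inline int get_sm_version_sketch() {
  int device = 0;
  cudaGetDevice(&device);
  int sm_major = 0, sm_minor = 0;
  cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
  return sm_major * 10 + sm_minor;
}

int main() {
  std::printf("SM version: %d\n", get_sm_version_sketch());
  return 0;
}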
@@ -136,4 +136,4 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
      ElementAccumulator, DefaultScaleMode>;
};

} // namespace cutlass_extensions
@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -54,7 +54,7 @@
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides API to promote (add) or scale (multiply_add) the results
/// from the tensor core accumulators to the main accumulators when the number
/// of MMAs reaches the max number of MMA interval specified by user, after that
/// the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////

@@ -64,7 +64,7 @@ namespace cutlass::gemm::collective {
template <
    class EngineAccum,
    class LayoutAccum>
struct GmmaFP8AccumulationWithScale {
  using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
  using ElementAccumulator = typename EngineAccum::value_type;

@@ -78,7 +78,7 @@ private:
  uint32_t accum_promotion_interval_;  // defines the max num of executed MMAs after which accum should be promoted.
  uint32_t mma_count_per_mainloop_iteration_;  // num of MMAs per k_tile of mainloop
  uint32_t mma_count_;  // current executed MMAs
  uint32_t reset_accum_flag_;  // accum needs to be zeroed or not.

  // promote or `add` the partial accumulators to main accumulator (FADD).
  CUTLASS_DEVICE

@@ -116,11 +116,11 @@ public:
      TensorAccum &accum,
      uint32_t accum_promotion_interval,
      uint32_t mma_count_per_mainloop_iteration)
      : accum_(accum),
        accum_promotion_interval_(accum_promotion_interval),
        mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
        mma_count_(0),
        reset_accum_flag_(0)
  {
    accum_temp_ = cute::make_fragment_like(accum);
  }

@@ -129,14 +129,14 @@ public:
  // Methods (Common)
  //

  CUTLASS_DEVICE
  TensorAccum& operator()() {
    return accum_temp_;
  }

  /// prepare the MMA accumulators when initialization or zeroing is required.
  CUTLASS_DEVICE
  bool prepare_if_needed() {
    return reset_accum_flag_;
  }

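The comment block above describes the promotion policy in prose; the following toy C++ model (ours, not CUTLASS code) shows the same bookkeeping — fold the temporary accumulator into the main one every `interval` MMAs, then zero it:

#include <cstdio>
#include <vector>

// Toy model of the promotion policy: partials accumulate in `temp_acc`, and
// every `interval` simulated MMAs they are folded (FADD) into `main_acc`
// and zeroed, bounding error growth in the low-precision-era partials.
struct PromotedAccumulator {
  std::vector<float> main_acc, temp_acc;
  int interval, mma_count = 0;
  PromotedAccumulator(size_t n, int iv)
      : main_acc(n, 0.f), temp_acc(n, 0.f), interval(iv) {}
  void accumulate(const std::vector<float>& partial) {  // one simulated MMA
    for (size_t i = 0; i < temp_acc.size(); ++i) temp_acc[i] += partial[i];
    if (++mma_count == interval) promote();
  }
  void promote() {
    for (size_t i = 0; i < main_acc.size(); ++i) {
      main_acc[i] += temp_acc[i];
      temp_acc[i] = 0.f;  // "tensor core accumulators are zeroed"
    }
    mma_count = 0;
  }
};

int main() {
  PromotedAccumulator acc(1, /*interval=*/4);
  for (int i = 0; i < 8; ++i) acc.accumulate({1.f});
  std::printf("main=%.1f temp=%.1f\n", acc.main_acc[0], acc.temp_acc[0]);  // main=8.0 temp=0.0
  return 0;
}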
@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -137,7 +137,7 @@ struct CollectiveMma<
  using PipelineParams = typename MainloopPipeline::Params;

  // Two threads per CTA are producers (1 for operand tile and 32 for scales)
  static constexpr int NumProducerThreadEvents = 33;

  static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
  static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;

@@ -161,11 +161,11 @@ struct CollectiveMma<
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  // Block scaling gmem-to-smem copy atom
  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
  using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;

  // Block scaling smem layout
  using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
  using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.

@@ -202,7 +202,7 @@ struct CollectiveMma<
    StrideA dA;
    ElementB const* ptr_B;
    StrideB dB;
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

@@ -228,7 +228,7 @@ struct CollectiveMma<
    uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
    // Block scaling factors for A and B
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

@@ -285,7 +285,7 @@ struct CollectiveMma<
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    bool implementable = true;
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});

@@ -346,7 +346,7 @@ struct CollectiveMma<
    auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
    auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});

    // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
    // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
    Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)

@@ -406,26 +406,26 @@ struct CollectiveMma<

    Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());

    Tensor gScaleA = local_tile(
        mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
        make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
    Tensor cScaleA = local_tile(
        cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
        make_coord(m_coord,_,l_coord));
    Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)

    // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
    TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
        Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
    TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
        Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
    ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
    ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);

    Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
    Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
    Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);

    Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
    Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);

@@ -455,7 +455,7 @@ struct CollectiveMma<
      }
    }

    // Allocate predicate tensors for a_scales (since we can't guarantee that
    // all scales are valid, since we could have partial tiles along M)
    Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
    #pragma unroll

@@ -536,7 +536,7 @@ struct CollectiveMma<

    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)

    // Block scaling
    Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
        Layout<

@@ -548,17 +548,17 @@ struct CollectiveMma<
    //
    // Define C accumulators and A/B partitioning
    //

    // Layout of warp group to thread mapping

    static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
                  stride<0>(typename TiledMma::BLayout{}) == 0 and
                  size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
                  size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
                  "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");

    constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
    Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
                                                  Int<NumThreadsPerWarpGroup>{});

    int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

@@ -590,7 +590,7 @@ struct CollectiveMma<

    // We release buffers to producer warps (dma load) with some MMAs in flight
    PipelineState smem_pipe_release = smem_pipe_read;

    // Per block scale values for operand A and B

    using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.

@@ -618,7 +618,7 @@ struct CollectiveMma<
      }

      int read_stage = smem_pipe_read.index();

      // Load per block scale values from shared memory to registers.
      scale_b = sScaleB[read_stage];
      CUTLASS_PRAGMA_UNROLL

@@ -668,7 +668,7 @@ struct CollectiveMma<

      int read_stage = smem_pipe_read.index();

      // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
      scale_b = sScaleB[read_stage];
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {

@@ -712,7 +712,7 @@ struct CollectiveMma<
      ++smem_pipe_read;
      ++smem_pipe_release;
    }

    accumulation.scale_residue_if_needed(tCrScaleAViewAsC);

    warpgroup_fence_operand(accumulation());
@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -50,4 +50,4 @@ struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8

//////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm

@@ -90,4 +90,4 @@ struct GemmMoeProblemVisitor
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -90,7 +90,7 @@ template <
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Used for partial specialization
    typename Enable = bool>
class Wint2xMmaMultistage :
    public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
  ///< Base class
@@ -57,7 +57,7 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){
      hasbias,
      ElementD,
      void>;

  constexpr int ScaleMsPerTile = size<0>(TileShape{});
  constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;

@@ -161,7 +161,7 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){
    arguments.scheduler.decomposition_mode = DecompositionMode::StreamK;
    arguments.scheduler.reduction_mode = ReductionMode::Nondeterministic;
  }


  Gemm gemm_op;

@@ -170,4 +170,4 @@ bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
    return false;
  }
  return true;
}

@@ -148,4 +148,4 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
    return false;
  }
  return true;
}
@@ -54,7 +54,7 @@ public:
  virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;

  virtual std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int k) const = 0;

protected:
  static constexpr int SPLIT_K_LIMIT = 7;
  static constexpr int MIN_M_TILE = 16;
@@ -93,8 +93,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat

PD_BUILD_STATIC_OP(extract_text_token_output)
    .Inputs({"max_seq_len",
             "max_seq_len_index",
             "mm_token_num_len",
             "seq_lens_this_time",
             "cu_seqlens_q",
             "score_text"})
@@ -105,7 +105,7 @@ __global__ void cudaCoreGemm(InputType const* __restrict__ act,
      }
    }
  }

  __syncthreads();
  for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
    int32_t mid = ii / TILE_N, nid = ii % TILE_N;

@@ -188,4 +188,4 @@ bool cuda_core_gemm_launcher(GemmParams const& params) {
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
@@ -61,7 +61,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
      st_idx += cur_st_len;
    }
  }

  while (idx < seq_lens_origin) {
    idx = idx + split_fuse_text_size;
    if (idx >= seq_lens_origin) {

@@ -116,7 +116,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
    while (ib < img_total && cur_img_len < chunk_image_token_number) {
      int token_times = 4;
      cur_img_len += (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times;
      ib++;
      chunk_image_number++;
    }
    image_chunk_selections_vector.emplace_back(chunk_image_number);
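The two loops above drive the split-fuse schedule: text is cut into `split_fuse_text_size` chunks, and each chunk greedily absorbs images until their token cost `h * w / token_times` exhausts `chunk_image_token_number`. An illustrative standalone model of that accounting (a sketch with invented driver data, not the op's API):

#include <cstdio>
#include <vector>

int main() {
  const int seq_len = 1000, split_fuse_text_size = 384;
  const int chunk_image_token_number = 256, token_times = 4;
  // grid_thw: (t, h, w) per image, flattened, as in grid_thw_cpu above.
  std::vector<int> grid_thw = {1, 16, 32, 1, 32, 32, 1, 16, 16};
  const int img_total = static_cast<int>(grid_thw.size() / 3);

  int ib = 0;
  for (int idx = 0; idx < seq_len; idx += split_fuse_text_size) {
    int cur_img_len = 0, chunk_image_number = 0;
    // Greedily take images until the chunk's image-token budget is spent.
    while (ib < img_total && cur_img_len < chunk_image_token_number) {
      cur_img_len += grid_thw[ib * 3 + 1] * grid_thw[ib * 3 + 2] / token_times;
      ++ib;
      ++chunk_image_number;
    }
    std::printf("chunk at %d: %d images, %d image tokens\n",
                idx, chunk_image_number, cur_img_len);
  }
  return 0;
}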
@@ -88,7 +88,7 @@ void sent_key_value_by_remote_ptr(
#ifdef DEBUG_IPC_SENT
    std::cout << "remote_key_tensor_sent_ptr:" << (int64_t)remote_key_tensor_sent_ptr
              << " local_key_tensor_sent_ptr:" << (int64_t)local_key_tensor_sent_ptr
              << " local_device_id:" << local_device_id
              << " remote_device_id:" << remote_device_id
              << " block_idx_stride:" << block_idx_stride
              << " block_size_byte:" << block_size_byte

@@ -107,25 +107,25 @@ void sent_key_value_by_remote_ptr(
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeerAsync(
        reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
        local_device_id,
        block_size_byte,
        stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeer(
        reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
        local_device_id,
        block_size_byte);
#endif
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("CUDA Error: %s\n", cudaGetErrorString(err));
    }
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaDeviceSynchronize();
@@ -140,7 +140,7 @@ void sent_key_value_by_remote_ptr(
#ifdef DEBUG_IPC_SENT
    std::cout << "remote_value_tensor_sent_ptr:" << (int64_t)remote_value_tensor_sent_ptr
              << " local_value_tensor_sent_ptr:" << (int64_t)local_value_tensor_sent_ptr
              << " local_device_id:" << local_device_id
              << " remote_device_id:" << remote_device_id
              << " block_idx_stride:" << block_idx_stride
              << " block_size_byte:" << block_size_byte

@@ -159,26 +159,26 @@ void sent_key_value_by_remote_ptr(
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeerAsync(
        reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
        local_device_id,
        block_size_byte,
        stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    cudaMemcpyPeer(
        reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
        remote_device_id,
        reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
        local_device_id,
        block_size_byte);
    cudaDeviceSynchronize();
#endif
    err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("CUDA Error: %s\n", cudaGetErrorString(err));
    }
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
    PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr),

@@ -316,11 +316,11 @@ void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
  cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
  cudaStreamSynchronize(cuda_stream);
}

PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
    .Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"})
    .Attrs({"block_num: int",
            "local_device_id: int",
            "remote_device_id: int",
            "cuda_stream_raw: int64_t"})
    .Outputs({"local_key_tensor_out", "local_value_tensor_out"})

@@ -332,4 +332,4 @@ PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
    .Attrs({"cuda_stream_raw: int64_t"})
    .Outputs({"local_key_tensor_out", "local_value_tensor_out"})
    .SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"}, {"local_value_tensor", "local_value_tensor_out"}})
    .SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));
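The `#ifdef` branches above switch between an async peer copy and a synchronous debug copy. A minimal sketch of the underlying pattern (standard CUDA runtime calls; the wrapper function is illustrative, not from this commit):

#include <cstdio>
#include <cuda_runtime.h>

// Copy one KV block from the local GPU's cache straight into the remote
// GPU's cache on a stream, as in the IPC transfer path above.
void copy_block_peer(void* remote_ptr, int remote_device,
                     const void* local_ptr, int local_device,
                     size_t block_size_byte, cudaStream_t stream) {
  cudaSetDevice(local_device);
  // Enabling peer access is per device pair; treat "already enabled" as OK.
  cudaError_t st = cudaDeviceEnablePeerAccess(remote_device, 0);
  if (st != cudaSuccess && st != cudaErrorPeerAccessAlreadyEnabled) {
    // cudaMemcpyPeerAsync still works without mapped peer access, but falls
    // back to a staged copy through host memory.
    std::printf("peer access unavailable: %s\n", cudaGetErrorString(st));
  }
  cudaMemcpyPeerAsync(remote_ptr, remote_device, local_ptr, local_device,
                      block_size_byte, stream);
}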
|
@@ -57,5 +57,3 @@ paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
|||||||
num_experts);
|
num_experts);
|
||||||
return token_nums_per_expert;
|
return token_nums_per_expert;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
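`count_tokens_per_expert_func` above is, at heart, a histogram over the routing ids; a minimal CPU reference under that assumption (shapes and names inferred from the signature):

#include <cstdint>
#include <cstdio>
#include <vector>

// For each expert, count how many routed (token, top-k slot) entries it gets.
std::vector<int64_t> count_tokens_per_expert_ref(
    const std::vector<int64_t>& topk_ids, int num_experts) {
  std::vector<int64_t> token_nums_per_expert(num_experts, 0);
  for (int64_t e : topk_ids) {
    if (e >= 0 && e < num_experts) ++token_nums_per_expert[e];  // skip padding ids
  }
  return token_nums_per_expert;
}

int main() {
  // 4 tokens, top-2 routing over 4 experts.
  std::vector<int64_t> topk_ids = {0, 2, 1, 2, 3, 0, 2, 1};
  for (int64_t n : count_tokens_per_expert_ref(topk_ids, 4))
    std::printf("%lld ", static_cast<long long>(n));
  std::printf("\n");  // prints: 2 2 3 1
  return 0;
}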
@@ -737,7 +737,7 @@ void MoeFastHardamardWrapper(const T *x_data,
  bool FLAGS_hardamard_use_diagonal_block_matrix = true;

  static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
  static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
      stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
  constexpr int kThreads = 128;
  if (FLAGS_hardamard_use_diagonal_block_matrix) {
|
|||||||
int num_bits_;
|
int num_bits_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace phi
|
} // namespace phi
|
||||||
|
@@ -360,10 +360,10 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  T val = T(threadDataExp * normalizing_factor);

  // top_k
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduceP::TempStorage tmpStorageP;

@@ -374,10 +374,10 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities

    if (threadIdx.x < num_experts) {
      cub_kvp inp_kvp;
      int expert = threadIdx.x;
      inp_kvp.key = expert;
      inp_kvp.value = bias ? val + bias[expert] : val;

@@ -518,12 +518,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }

  __syncthreads();

  T val = T(threadDataExp * normalizing_factor);

  // top_k
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduceP::TempStorage tmpStorageP;

@@ -541,7 +541,7 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i

    if (threadIdx.x < num_experts) {
      cub_kvp inp_kvp;
      int expert = threadIdx.x;
      inp_kvp.key = expert;
      inp_kvp.value = bias ? val + bias[expert] : val;

|
|||||||
const T* unpermuted_input,
|
const T* unpermuted_input,
|
||||||
OutT* permuted_output,
|
OutT* permuted_output,
|
||||||
const int* expanded_dest_row_to_expanded_source_row,
|
const int* expanded_dest_row_to_expanded_source_row,
|
||||||
const int *expert_idx_per_token,
|
const int *expert_idx_per_token,
|
||||||
const float *w4a8_in_scale,
|
const float *w4a8_in_scale,
|
||||||
int* expanded_source_row_to_expanded_dest_row,
|
int* expanded_source_row_to_expanded_dest_row,
|
||||||
const int64_t num_rows,
|
const int64_t num_rows,
|
||||||
@@ -1088,7 +1088,7 @@ __global__ void initialize_moe_routing_kernel(
|
|||||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||||
expanded_dest_row;
|
expanded_dest_row;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (expanded_dest_row < active_rows) {
|
if (expanded_dest_row < active_rows) {
|
||||||
|
|
||||||
const int expert_idx = expert_idx_per_token[expanded_dest_row];
|
const int expert_idx = expert_idx_per_token[expanded_dest_row];
|
||||||
@@ -1130,7 +1130,7 @@ static void run(
|
|||||||
const T* unpermuted_input,
|
const T* unpermuted_input,
|
||||||
OutT* permuted_output,
|
OutT* permuted_output,
|
||||||
const int* expanded_dest_row_to_expanded_source_row,
|
const int* expanded_dest_row_to_expanded_source_row,
|
||||||
const int *expert_idx_per_token,
|
const int *expert_idx_per_token,
|
||||||
const float *w4a8_in_scale,
|
const float *w4a8_in_scale,
|
||||||
int* expanded_source_row_to_expanded_dest_row,
|
int* expanded_source_row_to_expanded_dest_row,
|
||||||
const int64_t num_rows,
|
const int64_t num_rows,
|
||||||
|
@@ -17,7 +17,7 @@
// topk warps
template<typename T, int VecSize>
__global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int* permute_indices_per_token, const T* x, const int64_t* topk_idx, const int token_num, const int topk, const int num_vecs, const int hidden, const int max_tokens_per_expert) {

  AlignedVector<T, VecSize> in_vec;

  const int bid = blockIdx.x;

@@ -32,7 +32,7 @@ __global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int
    }
    tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);


    for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) {
      Load<T, VecSize>(x + token_idx * hidden + hidden_vec_id * VecSize, &in_vec);
      Store<T, VecSize>(in_vec, out + tgt_expert_id * max_tokens_per_expert * hidden + tgt_expert_token * hidden + hidden_vec_id * VecSize);

@@ -81,7 +81,7 @@ std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
      permute_indices_per_token.data<int32_t>(),
      reinterpret_cast<const DataType_ *>(x.data<data_t>()),
      topk_idx.data<int64_t>(),
      token_num, topk, num_vecs,
      hidden, max_tokens_per_expert
  );

@@ -112,4 +112,4 @@ PD_BUILD_STATIC_OP(moe_deepgemm_permute)
    .Inputs({"x", "topk_idx"})
    .Outputs({"permute_output", "token_nums_per_expert", "permute_indices_per_token"})
    .Attrs({"num_experts: int", "max_tokens_per_expert: int"})
    .SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));
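A hedged CPU reference for the permute the kernel above performs (shapes inferred from the signature; the kernel itself assigns slots with an atomic counter and broadcasts them across the warp with `__shfl_sync`):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scatter each (token, top-k slot) pair into a fixed-capacity per-expert
// buffer of shape [num_experts, max_tokens_per_expert, hidden].
void moe_permute_ref(const std::vector<float>& x, const std::vector<int64_t>& topk_idx,
                     int token_num, int topk, int hidden, int num_experts,
                     int max_tokens_per_expert, std::vector<float>& out,
                     std::vector<int>& token_nums_per_expert,
                     std::vector<int>& permute_indices_per_token) {
  out.assign((size_t)num_experts * max_tokens_per_expert * hidden, 0.f);
  token_nums_per_expert.assign(num_experts, 0);
  permute_indices_per_token.assign((size_t)token_num * topk, -1);
  for (int t = 0; t < token_num; ++t)
    for (int j = 0; j < topk; ++j) {
      int e = (int)topk_idx[(size_t)t * topk + j];
      int slot = token_nums_per_expert[e]++;       // atomicAdd in the kernel
      if (slot >= max_tokens_per_expert) continue; // over-capacity: dropped
      permute_indices_per_token[(size_t)t * topk + j] = slot;
      std::copy(x.begin() + (size_t)t * hidden, x.begin() + (size_t)(t + 1) * hidden,
                out.begin() + ((size_t)e * max_tokens_per_expert + slot) * hidden);
    }
}

int main() {
  std::vector<float> x = {1, 2, 3, 4};           // 2 tokens, hidden = 2
  std::vector<int64_t> topk_idx = {0, 1, 1, 0};  // top-2 routing over 2 experts
  std::vector<float> out; std::vector<int> counts, slots;
  moe_permute_ref(x, topk_idx, 2, 2, 2, 2, 4, out, counts, slots);
  std::printf("expert0=%d expert1=%d\n", counts[0], counts[1]);  // 2 2
  return 0;
}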
@@ -232,12 +232,12 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,

/**
 * @brief Mixture of Experts (MoE) Expert Dispatch Operator
 *
 * This operator performs the following key functions:
 * 1. Computes top-k experts for each input token based on gating scores
 * 2. Permutes input tokens according to their selected experts for efficient expert processing
 * 3. Computes prefix sums of tokens per expert for group_gemm optimization
 *
 * Inputs:
 * - input: The input tensor to be routed to experts
 *          Shape: [total_tokens, hidden_size]

@@ -246,7 +246,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
 *          Shape: [total_tokens, expert_num]
 *          dtype: must be float32
 * - gating_correction_bias: Optional bias term for gating correction (expert_num)
 *
 * Outputs:
 * - permute_input: Permuted input tensor organized by expert
 *                  Shape: [moe_topk * total_tokens, hidden_size]

@@ -263,7 +263,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
 * - top_k_indices: Indices of selected top-k experts for each token
 *                  Shape: [total_tokens, moe_topk]
 *                  dtype: int32
 *
 * Attributes:
 * - moe_topk: Number of experts to select for each token (k value in top-k routing)
 * - group_moe: Whether to perform group softmax within the operator

@@ -272,7 +272,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
 * - topk_only_mode: Operation mode selector
 *                   (true: only performs topk selection without softmax,
 *                    false: performs full softmax+topk computation)
 *
 * Note:
 * - The operator requires 2D input format [total_tokens, hidden_size]
 * - For optimal performance, expert_num should be a power of 2 when possible

@@ -283,7 +283,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
            paddle::Optional("gating_correction_bias"),
            paddle::Optional("w4a8_in_scale")})
    .Outputs({"permute_input", "tokens_expert_prefix_sum",
              "permute_indices_per_token", "topk_weight", "topk_idx",
              "expert_idx_per_token"})
    .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
    .SetKernelFn(PD_KERNEL(MoeExpertDispatch))
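The documentation above pins down the contract but not the mechanics. As a rough host-side sketch (all names ours, not the operator's implementation): grouping rows by expert with a stable sort and taking a prefix sum over per-expert counts yields exactly the kind of `tokens_expert_prefix_sum` row ranges each group GEMM consumes:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int num_experts = 4;
  // Expert id per expanded row (total_tokens * moe_topk entries).
  std::vector<int> expert_per_row = {2, 0, 2, 1, 0, 2};

  // Stable sort of row indices by expert groups rows expert-by-expert
  // while preserving token order inside each expert's group.
  std::vector<int> order(expert_per_row.size());
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(),
                   [&](int a, int b) { return expert_per_row[a] < expert_per_row[b]; });

  // Prefix sum over counts: expert e owns permuted rows [prefix[e], prefix[e+1]).
  std::vector<int> prefix(num_experts + 1, 0);
  for (int e : expert_per_row) ++prefix[e + 1];
  std::partial_sum(prefix.begin(), prefix.end(), prefix.begin());

  for (int e = 0; e < num_experts; ++e)
    std::printf("expert %d: rows [%d, %d)\n", e, prefix[e], prefix[e + 1]);
  return 0;
}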
@@ -263,4 +263,4 @@ PD_BUILD_OP(moe_redundant_topk_select)
    .SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
    .SetKernelFn(PD_KERNEL(MoERedundantTopKSelectKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectKernelInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype));
@@ -106,4 +106,4 @@ template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4,

template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );

}
@@ -36,4 +36,4 @@ struct msgdata {
struct msgdatakv {
  long mtype;
  int mtext[MAX_BSZ * 3 + 2];  // encoder_count, layer_id, bid- pair
};
@@ -14,9 +14,10 @@
 """read_ids"""
 
 import os
-import numpy as np
 import struct
 
+import numpy as np
+
 
 def deserialize_from_file(fp):
     """deserialize from file"""
@@ -13,9 +13,10 @@
 # limitations under the License.
 """read temp_ids from file"""
 import os
-import numpy as np
 import struct
 
+import numpy as np
+
 
 def deserialize_from_file(fp):
     """
@@ -15,7 +15,7 @@
#include "remote_cache_kv_ipc.h"

RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data RemoteCacheKvIpc::kv_complete_signal_meta_data;
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query
    RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query;
void* RemoteCacheKvIpc::kv_complete_signal_identity_ptr = nullptr;
bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;

@@ -118,4 +118,3 @@ void CUDART_CB RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_que
   RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
   // std::printf("#### save_cache_kv_complete_signal_layerwise_per_query);
 }
-
@@ -71,7 +71,7 @@ struct RemoteCacheKvIpc {
      }
    }
    msg_sed.mtext[0] = encoder_count;

    if (!inited) {
      // just init once
      const int msg_id = 1024 + rank;

@@ -90,7 +90,7 @@ struct RemoteCacheKvIpc {
      assert(layer_id_ <= num_layers_);
    }
  };

  static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data;
  static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query;
  static void* kv_complete_signal_identity_ptr;
@@ -125,7 +125,7 @@ void group_wise_scale(ScaleT* scale,
  }
}

std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &input,
                                                        int groupsize,
                                                        std::string scale_dtype) {
  auto input_cpu = input.copy_to(paddle::CPUPlace(), false);
@@ -139,47 +139,47 @@ std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &in
    if (groupsize > 0) {
      scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace());
      group_wise_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize);
      group_wise_quant(packed_int4.data<int8_t>(),
                       input_cpu.data<float>(),
                       scale.data<phi::dtype::bfloat16>(),
                       k,
                       n,
                       groupsize);
    } else {
      scale = paddle::full({shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace());
      per_channel_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f);
      per_channel_quant(packed_int4.data<int8_t>(),
                        input_cpu.data<float>(),
                        scale.data<phi::dtype::bfloat16>(),
                        k,
                        n);
    }
  } else if (scale_dtype == "float16") {
    if (groupsize > 0) {
      scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
      group_wise_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize);
      group_wise_quant(packed_int4.data<int8_t>(),
                       input_cpu.data<float>(),
                       scale.data<phi::dtype::float16>(),
                       k,
                       n,
                       groupsize);
    } else {
      scale = paddle::full({shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
      per_channel_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f);
      per_channel_quant(packed_int4.data<int8_t>(),
                        input_cpu.data<float>(),
                        scale.data<phi::dtype::float16>(),
                        k,
                        n);
    }
  }

  auto out = paddle::full({shape[1] / 2, shape[0]}, 0, paddle::DataType::INT8, paddle::CPUPlace());
  preprocess_weights_for_mixed_gemm(
      out.data<int8_t>(),
      packed_int4.data<int8_t>(),
      {k, n},
      kernels::cutlass_kernels::QuantType::W4_AFP8,
      false);
  return {out, scale};
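
Editor's note: the 7.0f passed to the scale helpers above is the maximum magnitude of the signed int4 range used here, so each group (or channel) maps its largest absolute weight onto ±7 before two int4 values are packed per int8. A minimal NumPy sketch of the group-wise path, with illustrative names and an assumed nibble order (the real op additionally transposes and repacks the result via preprocess_weights_for_mixed_gemm):

```python
import numpy as np

def group_wise_quantize(w: np.ndarray, groupsize: int):
    """Toy model of group_wise_scale + group_wise_quant for a [k, n] float weight."""
    k, n = w.shape
    groups = w.reshape(k // groupsize, groupsize, n)
    scale = np.abs(groups).max(axis=1) / 7.0            # one scale per (group, column)
    q = np.clip(np.rint(groups / scale[:, None, :]), -7, 7)
    q = q.astype(np.int8).reshape(k, n)
    # Pack two int4 values per int8 along k, halving that dimension
    # (which nibble holds the even row is an assumption here).
    lo = q[0::2].astype(np.uint8) & 0x0F
    hi = (q[1::2].astype(np.uint8) & 0x0F) << 4
    return (lo | hi).astype(np.int8), scale

packed, scale = group_wise_quantize(np.random.randn(128, 64).astype(np.float32), 64)
assert packed.shape == (64, 64) and scale.shape == (2, 64)
```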

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -27,7 +27,7 @@

std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
                                              const std::string shm_name,
                                              const std::vector<int>& shape) {
  volatile shmStruct *shm = NULL;
  sharedMemoryInfo info;
  if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
@@ -62,4 +62,4 @@ PD_BUILD_STATIC_OP(share_external_data)
    .Inputs({"input"})
    .Outputs({"output"})
    .Attrs({"shm_name: std::string", "shape: std::vector<int>"})
    .SetKernelFn(PD_KERNEL(ShareExternalData));

@@ -19,7 +19,7 @@
// #define DEBUG_EAGLE_KERNEL

__global__ void ComputeOrderKernel(
    const int* seq_lens_this_time,
    const int* seq_lens_encoder,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
@@ -47,7 +47,7 @@ __global__ void ComputeOrderKernel(
        printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif
        for (int j = 0; j < cur_seq_lens_encoder; j++) {
          position_map[in_offset++] = out_offset++;
        }
        // 2. base model encoder. Base step=0
      } else if (cur_base_model_seq_lens_encoder != 0) {
@@ -69,13 +69,13 @@ __global__ void ComputeOrderKernel(
          in_offset += cur_base_model_seq_lens_this_time;
        } else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
          printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
          position_map[in_offset + accept_num - 2] = out_offset++;
          position_map[in_offset + accept_num - 1] = out_offset++;
          in_offset += cur_base_model_seq_lens_this_time;
        }
      }
    }
    output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
@@ -208,7 +208,7 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
    }
    case paddle::DataType::BFLOAT16: {
      return DispatchDtype<paddle::DataType::BFLOAT16>(
          input,
          seq_lens_this_time,
          seq_lens_encoder,
          seq_lens_decoder,
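
Editor's note: position_map above is an input-position to output-position scatter map. For every base-model token whose hidden state the draft model will need, the kernel records where it should land in the compacted output, and output_token_num[0] ends up as the total. A rough Python rendering of the accept-all branch visible in this hunk, under the simplifying assumption that only that branch runs (when every draft token is accepted, just the last two tokens of the sequence's base-model output are kept; the partial-acceptance branches follow the same scatter pattern):

```python
def accept_all_branch(position_map, in_offset, out_offset, accept_num, base_seq_len_this_time):
    # Keep the hidden states of the last two accepted tokens for this sequence.
    position_map[in_offset + accept_num - 2] = out_offset
    position_map[in_offset + accept_num - 1] = out_offset + 1
    # Skip past this sequence's tokens in the flattened base-model output.
    return in_offset + base_seq_len_this_time, out_offset + 2

pm = [-1] * 8  # illustrative: one sequence, 8 base-model tokens
in_off, out_off = accept_all_branch(pm, 0, 0, accept_num=5, base_seq_len_this_time=8)
assert pm[3] == 0 and pm[4] == 1 and (in_off, out_off) == (8, 2)
```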

@@ -72,7 +72,7 @@ __global__ void computeOrderKernel(
  output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
  printf("position map output_token_num%d:\n", output_token_num[0]);
  for (int i = 0; i < output_token_num[0]; i++) {
    printf("%d ", src_map[i]);
  }
  printf("\n");
@@ -187,4 +187,4 @@ PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
             "seq_lens_this_time",
             "step_idx"})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));

@@ -26,7 +26,7 @@ __global__ void RebuildAppendPaddingKernel(
    const int seq_len,
    const int dim_embed,
    const size_t elem_nums) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;
  const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
@@ -42,7 +42,7 @@ __global__ void RebuildAppendPaddingKernel(

    const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
    const int bias_idx = i % dim_embed;

    Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
    Store<T, VecSize>(src_vec, &out[i]);
  }
@@ -78,14 +78,14 @@ std::vector<paddle::Tensor> DispatchDtype(
  GetNumBlocks(pack_num, &grid_size);

  RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
      reinterpret_cast<DataType_*>(out.data<data_t>()),
      reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
      cum_offsets.data<int32_t>(),
      seq_len_encoder.data<int32_t>(),
      seq_len_decoder.data<int32_t>(),
      output_padding_offset.data<int32_t>(),
      max_seq_len,
      dim_embed,
      elem_nums);
  return {out};
}
@@ -99,7 +99,7 @@ std::vector<paddle::Tensor> RebuildAppendPadding(
    const paddle::Tensor& output_padding_offset,
    const int max_seq_len) {


  switch (full_hidden_states.dtype()) {
    case paddle::DataType::BFLOAT16:
      return DispatchDtype<paddle::DataType::BFLOAT16>(
@@ -137,7 +137,7 @@ std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(


PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
    .Inputs({"full_hidden_states",
             "cum_offsets",
             "seq_len_encoder",
             "seq_len_decoder",
@@ -146,4 +146,4 @@ PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(RebuildAppendPadding))
    .SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));

@@ -93,7 +93,7 @@ __global__ void speculate_free_and_reschedule(bool *stop_flags,
      used_list_len[tid] = 0;
    }
  } else if (seq_lens_this_time[tid] != 0 && max_possible_block_idx < block_num_per_seq &&
             block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
                              1) /
                             block_size] == -1) {
    // Record the positions that need a block allocated, and the total count.
@@ -347,7 +347,7 @@ PD_BUILD_STATIC_OP(speculate_step_reschedule)
             "next_tokens",
             "first_token_ids",
             "accept_num"})
    .Attrs({"block_size: int",
            "encoder_decoder_block_num: int",
            "max_draft_tokens: int"})
    .Outputs({"stop_flags_out",

@@ -60,7 +60,7 @@ __global__ void recover_block_system_cache(int *recover_block_list, // [bsz]
    const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len);
    ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
    printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, ori_free_list_len: %d\n",
           recover_id, ori_seq_len_encoder, step_idx_now, seq_len, ori_free_list_len_tid0, ori_free_list_len);
#endif
  }
@@ -95,7 +95,7 @@ void StepSystemCache(const paddle::Tensor& stop_flags,
                     const paddle::Tensor& recover_lens,
                     const paddle::Tensor& need_block_list,
                     const paddle::Tensor& need_block_len,
                     const paddle::Tensor& used_list_len,
                     const paddle::Tensor& free_list,
                     const paddle::Tensor& free_list_len,
                     const paddle::Tensor& input_ids,
@@ -178,7 +178,7 @@ void StepSystemCache(const paddle::Tensor& stop_flags,
  }

PD_BUILD_STATIC_OP(step_system_cache)
    .Inputs({"stop_flags",
             "seq_lens_this_time",
             "ori_seq_lens_encoder",
             "ori_seq_lens_decoder",

@@ -68,26 +68,26 @@ void SwapCache(const paddle::Tensor& cache_gpu, // gpu
  switch (cache_gpu.dtype()) {
    case paddle::DataType::BFLOAT16:
      return SwapCacheImpl<paddle::DataType::BFLOAT16>(
          cache_gpu,
          cache_cpu_ptr,
          max_block_num_cpu,
          swap_block_ids_gpu,
          swap_block_ids_cpu,
          mode);
    case paddle::DataType::FLOAT16:
      return SwapCacheImpl<paddle::DataType::FLOAT16>(
          cache_gpu,
          cache_cpu_ptr,
          max_block_num_cpu,
          swap_block_ids_gpu,
          swap_block_ids_cpu,
          mode);
    case paddle::DataType::UINT8:
      return SwapCacheImpl<paddle::DataType::UINT8>(
          cache_gpu,
          cache_cpu_ptr,
          max_block_num_cpu,
          swap_block_ids_gpu,
          swap_block_ids_cpu,
          mode);
    default:
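
Editor's note: SwapCacheImpl itself is outside this hunk, but the dispatch above makes the contract clear: the op moves whole KV-cache blocks between the GPU tensor and a CPU pool, pairing the i-th entry of swap_block_ids_gpu with the i-th entry of swap_block_ids_cpu, in a direction chosen by mode. A toy NumPy model of that movement (the boolean direction flag is an assumption; the real op passes an integer mode and also handles max_block_num_cpu and dtype dispatch as shown):

```python
import numpy as np

def swap_cache_blocks(cache_gpu, cache_cpu, ids_gpu, ids_cpu, gpu_to_cpu: bool):
    """Copy paired KV blocks between the two pools, block by block."""
    for g, c in zip(ids_gpu, ids_cpu):
        if gpu_to_cpu:
            cache_cpu[c] = cache_gpu[g]
        else:
            cache_gpu[g] = cache_cpu[c]

gpu = np.arange(4 * 2 * 3, dtype=np.float16).reshape(4, 2, 3)  # [num_blocks, ...]
cpu = np.zeros((8, 2, 3), dtype=np.float16)
swap_cache_blocks(gpu, cpu, ids_gpu=[1, 3], ids_cpu=[0, 5], gpu_to_cpu=True)
assert np.array_equal(cpu[5], gpu[3])
```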

@@ -47,7 +47,7 @@ inline cudaError_t GetGridSize(int64_t n, int block_size, int num_waves, int* nu

template<typename T, int VecSize>
__global__ void text_image_scatter_kernel(
    T* input_ptr,
    T* text_gather_ptr,
    T* image_gather_ptr,
    int32_t* token_type_ids,
@@ -72,8 +72,8 @@ __global__ void text_image_scatter_kernel(
    int32_t token_type_ids_num = token_type_ids[token_idx];

    int64_t input_load_offset = token_idx * hidden_size + hidden_offset;

    Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
    for(int vi = 0; vi < VecSize; ++vi) {
      text_imgaes_vec[vi] = input_ptr_vec[vi];
@@ -92,7 +92,7 @@ __global__ void text_image_scatter_kernel(

template<typename T, int VecSize>
__global__ void text_image_gather_kernel(
    T* output_ptr,
    T* text_gather_ptr,
    T* image_gather_ptr,
    int32_t* token_type_ids,
@@ -131,8 +131,8 @@ __global__ void text_image_gather_kernel(
    }

    int64_t input_load_offset = token_idx * hidden_size + hidden_offset;

    Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset);
  }
}

@@ -159,7 +159,7 @@ void LaunchTextImageGatherScatter(
  const int64_t tot_element_num = token_num * hidden_size;

  int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize;

  const int block_size = 128;
  int grid_index = (token_num + block_size - 1) / block_size;
  constexpr int32_t kNumWaves = 16;
@@ -170,8 +170,8 @@ void LaunchTextImageGatherScatter(
  if (is_scatter) {
    text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
        reinterpret_cast<DataType_*>(input.data<data_t>()),
        reinterpret_cast<DataType_*>(text_input.data<data_t>()),
        reinterpret_cast<DataType_*>(image_input.data<data_t>()),
        reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()),
        reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
        reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
@@ -181,8 +181,8 @@ void LaunchTextImageGatherScatter(
  } else {
    text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
        reinterpret_cast<DataType_*>(input.data<data_t>()),
        reinterpret_cast<DataType_*>(text_input.data<data_t>()),
        reinterpret_cast<DataType_*>(image_input.data<data_t>()),
        reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()),
        reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
        reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
@@ -216,8 +216,8 @@ void TextImageGatherScatter(

PD_BUILD_STATIC_OP(text_image_gather_scatter)
    .Inputs({"input",
             "text_input",
             "image_input",
             "token_type_ids",
             "text_index",
             "image_index"})
@@ -229,5 +229,5 @@ PD_BUILD_STATIC_OP(text_image_gather_scatter)
    .SetInplaceMap({{"text_input", "text_input_out"},
                    {"image_input", "image_input_out"},
                    {"text_index", "text_index_out"},
                    {"image_index", "image_index_out"}})
    .SetKernelFn(PD_KERNEL(TextImageGatherScatter));

@@ -16,7 +16,7 @@

template <int VecSize>
__global__ void text_image_index_out_kernel(
    int32_t* token_type_ids,
    int32_t* text_index,
    int32_t* image_index,
    const int64_t token_num
@@ -25,7 +25,7 @@ __global__ void text_image_index_out_kernel(
  if (global_thread_idx >= 1) return;
  int text_count = 0;
  int images_count = 0;

  for (int i = 0; i < token_num; ++i) {
    // printf(" %d %d %d %d \n", text_index[i], text_count, images_count, i);
    if (token_type_ids[i] == 0) {
@@ -60,5 +60,5 @@ PD_BUILD_STATIC_OP(text_image_index_out)
    .Outputs({"text_index_out",
              "image_index_out"})
    .SetInplaceMap({{"text_index", "text_index_out"},
                    {"image_index", "image_index_out"}})
    .SetKernelFn(PD_KERNEL(TextImageIndexOut));
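
Editor's note: this deliberately single-threaded kernel assigns each token a running index within its own modality; those indices are what the scatter/gather kernels earlier in this diff use to move rows between the mixed token stream and the per-modality buffers. A Python equivalent (the non-text branch is inferred from the counters, since the else body falls outside this hunk):

```python
def text_image_index_out(token_type_ids):
    text_index = [0] * len(token_type_ids)
    image_index = [0] * len(token_type_ids)
    text_count = images_count = 0
    for i, token_type in enumerate(token_type_ids):
        if token_type == 0:          # text token
            text_index[i] = text_count
            text_count += 1
        else:                        # image token
            image_index[i] = images_count
            images_count += 1
    return text_index, image_index

assert text_image_index_out([0, 1, 0, 1, 1]) == ([0, 0, 1, 0, 0], [0, 0, 0, 1, 2])
```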

@@ -810,4 +810,4 @@ PD_BUILD_STATIC_OP(tune_cublaslt_gemm)
            "is_test: bool",
            "is_read_from_file: bool",
            "path: std::string"})
    .SetKernelFn(PD_KERNEL(TuneCublasltGemm));

@@ -33,7 +33,7 @@ __global__ void update_inputs_beam_kernel(
  if (block_idx == 0) {
    seq_lens_this_time[thread_idx] = seq_lens_this_time[bsz_index];
    seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
  }
  if (block_idx < seq_len) {
    input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
  }
@@ -74,8 +74,8 @@ void UpdateInputesBeam(

PD_BUILD_STATIC_OP(update_inputs_beam)
    .Inputs({"beam_width",
             "seq_lens_this_time",
             "seq_lens_encoder",
             "input_ids",
             "logits"})
    .Outputs({"seq_lens_this_time_out",
@@ -86,4 +86,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
                    {"seq_lens_encoder", "seq_lens_encoder_out"},
                    {"input_ids", "input_ids_out"},
                    {"logits", "logits_out"}})
    .SetKernelFn(PD_KERNEL(UpdateInputesBeam));

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FastDeploy custom ops """
"""setup for FastDeploy custom ops"""
import importlib
import json
import os
@@ -41,8 +41,7 @@ ROOT_DIR = Path(__file__).parent.parent

# cannot import envs directly because it depends on fastdeploy,
# which is not installed yet
envs = load_module_from_path('envs',
                             os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))

archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
@@ -143,8 +142,7 @@ def get_nvcc_version():
    """
    Get cuda version of nvcc.
    """
    nvcc_output = subprocess.check_output(["nvcc", "--version"],
                                          universal_newlines=True)
    nvcc_output = subprocess.check_output(["nvcc", "--version"], universal_newlines=True)
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    nvcc_cuda_version = float(output[release_idx].split(",")[0])
@@ -160,13 +158,19 @@ def get_gencode_flags(archs):
    for cc_val in cc_s:
        if cc_val == 90:
            arch_code = "90a"
            flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
            flags += [
                "-gencode",
                f"arch=compute_{arch_code},code=sm_{arch_code}",
            ]
        elif cc_val == 100:  # Assuming 100 is the code for Blackwell SM10.x
            # Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
            # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
            # "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
            arch_code = "100a"
            flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
            flags += [
                "-gencode",
                f"arch=compute_{arch_code},code=sm_{arch_code}",
            ]
        else:
            flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
    return flags
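
Editor's note: condensed, the branch structure above reduces to a table from compute capability to a gencode architecture suffix. The helper below is an editorial restatement for illustration, not a function in the repository:

```python
def gencode_for(cc_val: int) -> list:
    # 90 -> 90a and 100 -> 100a per the CUTLASS guidance cited above;
    # every other CC uses the plain sm_<cc> form.
    arch = {90: "90a", 100: "100a"}.get(cc_val, str(cc_val))
    return ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]

assert gencode_for(90) == ["-gencode", "arch=compute_90a,code=sm_90a"]
assert gencode_for(80) == ["-gencode", "arch=compute_80,code=sm_80"]
```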

@@ -194,7 +198,7 @@ if paddle.is_compiled_with_rocm():
        clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
        if not os.listdir(json_dir):
            raise ValueError("Git clone nlohmann_json failed!")
    sources=[
    sources = [
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/stop_generation.cu",
@@ -302,8 +306,7 @@ elif paddle.is_compiled_with_cuda():
    if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
        if not os.path.exists(cutlass_dir):
            os.makedirs(cutlass_dir)
        clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git",
                       cutlass_dir)
        clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
        if not os.listdir(cutlass_dir):
            raise ValueError("Git clone cutlass failed!")

@@ -312,8 +315,7 @@ elif paddle.is_compiled_with_cuda():
    if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
        if not os.path.exists(deep_gemm_dir):
            os.makedirs(deep_gemm_dir)
        clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git",
                       deep_gemm_dir)
        clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
        if not os.listdir(deep_gemm_dir):
            raise ValueError("Git clone DeepGEMM failed!")
    cur_path = os.path.dirname(os.path.abspath(__file__))
@@ -347,15 +349,13 @@ elif paddle.is_compiled_with_cuda():
        try:
            shutil.copytree(src_dir, dst_dir)
        except Exception as e:
            raise RuntimeError(
                f"Failed to copy from {src_dir} to {dst_dir}: {e}")
            raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")

    json_dir = "third_party/nlohmann_json"
    if not os.path.exists(json_dir) or not os.listdir(json_dir):
        if not os.path.exists(json_dir):
            os.makedirs(json_dir)
        clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git",
                       json_dir)
        clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
        if not os.listdir(json_dir):
            raise ValueError("Git clone nlohmann_json failed!")

@@ -372,7 +372,7 @@ elif paddle.is_compiled_with_cuda():
        "-Ithird_party/nlohmann_json/include",
    ]
    nvcc_version = get_nvcc_version()
    print(f'nvcc_version = {nvcc_version}')
    print(f"nvcc_version = {nvcc_version}")
    if nvcc_version >= 12.0:
        sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
    cc = max(get_sm_version(archs))
@@ -414,31 +414,24 @@ elif paddle.is_compiled_with_cuda():
        # Running generate fp8 gemm codes.
        # Common for SM89, SM90, SM100 (Blackwell)
        nvcc_compile_args += ["-DENABLE_FP8"]
        nvcc_compile_args += [
            "-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
        ]
        nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
        # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
        os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")

        if cc >= 90:  # Hopper and newer
            # SM90 (Hopper) specific auto-generation and flags
            if cc == 90:  # Only for SM90
                nvcc_compile_args += [
                    # The gencode for 90a is added in get_gencode_flags now
                    # "-gencode",
                    # "arch=compute_90a,code=compute_90a",
                    "-O3",
                    "-DNDEBUG",  # NDEBUG is common, consider moving if not specific to 90a
                ]
                print("SM90: Running SM90-specific FP8 kernel auto-generation.")
                os.system(
                    "python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
                os.system(
                    "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
                )
                os.system(
                    "python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
                )
                os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
                os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
                os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")

                nvcc_compile_args += [
                    "-DENABLE_SCALED_MM_SM90=1",
@@ -450,14 +443,14 @@ elif paddle.is_compiled_with_cuda():
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
                ]
            elif cc == 100 and nvcc_version >= 12.9:  # Blackwell SM100 specifics
                print("SM100 (Blackwell): Applying SM100 configurations.")
                nvcc_compile_args += [
                    # The gencode for 100a is added in get_gencode_flags
                    # "-gencode",
                    # "arch=compute_100a,code=compute_100a",
                    "-O3",  # Common optimization flag
                    "-DNDEBUG",  # Common debug flag
                    # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
                ]
                # Placeholder for SM100-specific kernel auto-generation scripts
@@ -469,18 +462,16 @@ elif paddle.is_compiled_with_cuda():

                # Add SM100 specific sources if any, e.g., for new hardware intrinsics
                # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
                pass  # No SM100 specific sources identified yet beyond what CUTLASS handles
            else:  # For cc >= 89 but not 90 or 100 (e.g. SM89)
                print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
                os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
                os.system(
                    "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
                os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

        else:  # For cc == 89 (Ada)
            print("SM89: Running generic FP8 kernel auto-generation.")
            os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
            os.system(
                "python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
            os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

        # Common FP8 sources for SM89+
        sources += [
@@ -493,7 +484,7 @@ elif paddle.is_compiled_with_cuda():
            "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
            "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
            "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
            "gpu_ops/fused_hadamard_quant_fp8.cu"
            "gpu_ops/fused_hadamard_quant_fp8.cu",
        ]

        sources += find_end_files(fp8_auto_gen_directory, ".cu")

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FASTDEPLOY base ops """
"""setup for FASTDEPLOY base ops"""

from paddle.utils.cpp_extension import CppExtension, setup

@@ -27,7 +27,8 @@ setup(
            "cpu_ops/rebuild_padding.cc",
        ],
        extra_compile_args=[
            "-DPy_LIMITED_API=0x03090000", "-DPADDLE_ON_INFERENCE"
            "-DPy_LIMITED_API=0x03090000",
            "-DPADDLE_ON_INFERENCE",
        ],
    ),
)

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FASTDEPLOY custom cpu ops """
"""setup for FASTDEPLOY custom cpu ops"""
import os
import subprocess
import tarfile
@@ -26,8 +26,7 @@ ROOT_DIR = Path(__file__).parent.parent
# which is not installed yet
from .setup_ops import load_module_from_path

envs = load_module_from_path('envs',
                             os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))

BUILDING_ARCS = []
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

@@ -48,17 +48,26 @@ def get_candidate_configs(sm):
    candidate_configs = list()

    hasbias = ("false", "true")
    KernelSchedule = (
        "KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", )
    EpilogueSchedule = ("TmaWarpSpecializedCooperative", )
    KernelSchedule = ("KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>",)
    EpilogueSchedule = ("TmaWarpSpecializedCooperative",)
    TileSchedule = ("PersistentScheduler", "StreamKScheduler")
    for act_tag in [
        ("noact", "Identity"),
        # ("relu", "ReLu"),
        # ("gelu", "GELU"),
    ]:
        candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
                                   EpilogueSchedule, TileSchedule)])
        candidate_configs.extend(
            [
                (
                    hasbias,
                    act_tag,
                    tiles,
                    KernelSchedule,
                    EpilogueSchedule,
                    TileSchedule,
                )
            ]
        )
    return candidate_configs


@@ -66,16 +75,13 @@ def get_shape_str(tile_shape):
    """
    return tile_shape string.
    """
    blocks, clusters = [
        s.replace(" ", "").strip("<>").split(",") for s in tile_shape
    ]
    blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
    blocks = [elem.strip("_") for elem in blocks]
    clusters = [elem.strip("_") for elem in clusters]
    return blocks, clusters
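
Editor's note: the one-line comprehension behaves exactly as before: it expects a (block, cluster) pair of tile-shape strings and strips the angle-bracket and underscore decoration. A self-contained check with an illustrative tile pair:

```python
def get_shape_str(tile_shape):
    # Same logic as the reformatted function above.
    blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
    blocks = [elem.strip("_") for elem in blocks]
    clusters = [elem.strip("_") for elem in clusters]
    return blocks, clusters

assert get_shape_str(("<_128_, _128_, _64_>", "<1, 2, 1>")) == (
    ["128", "128", "64"],
    ["1", "2", "1"],
)
```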


def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule,
                       tile_schedule):
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, tile_schedule):
    """
    check the cutlass config valid.
    """
@@ -304,13 +310,10 @@ def SubstituteTemplate(template, values_base):
    SubstituteTemplate
    """
    values = copy.deepcopy(values_base)
    if values.get("KernelSchedule"
                  ) is not None and "Auto" in values["KernelSchedule"]:
    if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
        values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
    if values.get("EpilogueSchedule"
                  ) is not None and "Auto" in values["EpilogueSchedule"]:
        values[
            "EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
    if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
        values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
    text = template
    changed = True
    while changed:
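
Editor's note: the loop body of SubstituteTemplate lies outside this hunk, but the shape is the classic fixed-point substitution: replace placeholders until a full pass changes nothing, so substituted values may themselves contain placeholders. A minimal sketch under that assumption, with ${key} placeholder syntax borrowed from similar CUTLASS code generators (the project's actual token syntax is not shown here):

```python
import copy

def substitute_template_sketch(template: str, values_base: dict) -> str:
    values = copy.deepcopy(values_base)
    text, changed = template, True
    while changed:
        changed = False
        for key, value in values.items():
            token = "${%s}" % key
            if token in text:
                text = text.replace(token, value)
                changed = True
    return text

assert (
    substitute_template_sketch("kernel_sm${sm}_${gemm_config}", {"sm": "90", "gemm_config": "128x128"})
    == "kernel_sm90_128x128"
)
```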

@@ -329,8 +332,7 @@ def parse_args():
    parse_args
    """
    parser = argparse.ArgumentParser(
        description=
        "The argument for generating the generic_mixed_gemm_kernelLauncher instance."
        description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
    )
    parser.add_argument(
        "--cuda_arch",
@@ -346,15 +348,15 @@ def parse_args():

# generate source .cu
def generate_source_cu(
    inputs_type: (str),
    inputs_type: str,
    outputs_type: (str),
    outputs_type: str,
    hasbiases: (str),
    hasbiases: str,
    act_tag: (str),
    act_tag: str,
    tiles: (str),
    tiles: str,
    KernelSchedule: (str),
    KernelSchedule: str,
    EpilogueSchedule: (str),
    EpilogueSchedule: str,
    TileSchedule: (str),
    TileSchedule: str,
    sm: str,
):
    """
    generate_source_cu
@@ -369,8 +371,11 @@ def generate_source_cu(
    for epilogue_schedule in EpilogueSchedule:
        for tile_schedule in TileSchedule:
            if not check_config_valid(
                    tile_config, kernel_schedule,
                    epilogue_schedule, tile_schedule):
                tile_config,
                kernel_schedule,
                epilogue_schedule,
                tile_schedule,
            ):
                continue
            value_dict = {
                "input_type": input_type,
@@ -385,30 +390,32 @@ def generate_source_cu(
                "SM": sm,
                "sm": sm[-2:],
            }
            all_code += SubstituteTemplate(
                GemmDeclare, value_dict)
            all_code += SubstituteTemplate(GemmDeclare, value_dict)

    return all_code


# generate gemm launch .cu
def generate_launch_gemm_cus(
    generate_dir: (str), inputs_type: (str), outputs_type: (str),
    fuse_gemm_configs: tuple, sm: str):
    generate_dir: str,
    inputs_type: str,
    outputs_type: str,
    fuse_gemm_configs: tuple,
    sm: str,
):
    """
    generate_launch_gemm_cus
    """
    act_tags = [single_config[1] for single_config in fuse_gemm_configs]

    single_config = fuse_gemm_configs[0]
    hasbiases: (str) = single_config[0]
    hasbiases: str = single_config[0]
    tiles: (str) = single_config[2]
    tiles: str = single_config[2]
    KernelSchedule: (str) = single_config[3]
    KernelSchedule: str = single_config[3]
    EpilogueSchedule: (str) = single_config[4]
    EpilogueSchedule: str = single_config[4]
    TileSchedule: (str) = single_config[5]
    TileSchedule: str = single_config[5]
    code_map = {}
    head_path = os.path.join(generate_dir,
                             f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
    head_path = os.path.join(generate_dir, f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
    head_all_code = LaunchGemmHead
    for tile_config in tiles:
        blocks, clusters = get_shape_str(tile_config)
@@ -418,19 +425,19 @@ def generate_launch_gemm_cus(
        for epilogue_schedule in EpilogueSchedule:
            gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
            for tile_schedule in TileSchedule:
                if not check_config_valid(tile_config, kernel_schedule,
                                          epilogue_schedule,
                                          tile_schedule):
                if not check_config_valid(
                    tile_config,
                    kernel_schedule,
                    epilogue_schedule,
                    tile_schedule,
                ):
                    continue
                gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
                value_dict = {
                    "sm":
                    sm[-2:],
                    "gemm_config":
                    gemm_config_str.replace("<", "").replace(">", ""),
                    "sm": sm[-2:],
                    "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
                }
                head_all_code += SubstituteTemplate(
                    LaunchGemmDeclare, value_dict)
                head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
    os.makedirs(generate_dir, exist_ok=True)
    with open(head_path, "w") as f:
        f.write(head_all_code)
@@ -444,19 +451,19 @@ def generate_launch_gemm_cus(
|
|||||||
for epilogue_schedule in EpilogueSchedule:
|
for epilogue_schedule in EpilogueSchedule:
|
||||||
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
|
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
|
||||||
for tile_schedule in TileSchedule:
|
for tile_schedule in TileSchedule:
|
||||||
if not check_config_valid(tile_shape, kernel_schedule,
|
if not check_config_valid(
|
||||||
epilogue_schedule,
|
tile_shape,
|
||||||
tile_schedule):
|
kernel_schedule,
|
||||||
|
epilogue_schedule,
|
||||||
|
tile_schedule,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
|
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
|
||||||
value_dict = {
|
value_dict = {
|
||||||
"sm":
|
"sm": sm[-2:],
|
||||||
sm[-2:],
|
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
|
||||||
"gemm_config":
|
|
||||||
gemm_config_str.replace("<", "").replace(">", ""),
|
|
||||||
}
|
}
|
||||||
source_all_code = SubstituteTemplate(
|
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
|
||||||
LaunchGemmPart0, value_dict)
|
|
||||||
type_id = 0
|
type_id = 0
|
||||||
for input_type in inputs_type:
|
for input_type in inputs_type:
|
||||||
for output_type in outputs_type:
|
for output_type in outputs_type:
|
||||||
@@ -476,16 +483,14 @@ def generate_launch_gemm_cus(
|
|||||||
"SM": sm,
|
"SM": sm,
|
||||||
"sm": sm[-2:],
|
"sm": sm[-2:],
|
||||||
}
|
}
|
||||||
source_all_code += SubstituteTemplate(
|
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
|
||||||
LaunchGemmPart1, value_dict)
|
|
||||||
type_id += 1
|
type_id += 1
|
||||||
source_all_code += LaunchGemmPart2
|
source_all_code += LaunchGemmPart2
|
||||||
gemm_config_str = gemm_config_str.replace("<", "").replace(
|
gemm_config_str = gemm_config_str.replace("<", "").replace(">", "")
|
||||||
">", "")
|
|
||||||
code_map[gemm_config_str] = source_all_code
|
code_map[gemm_config_str] = source_all_code
|
||||||
source_path = os.path.join(
|
source_path = os.path.join(
|
||||||
generate_dir,
|
generate_dir,
|
||||||
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
|
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
|
||||||
)
|
)
|
||||||
with open(source_path, "w") as f:
|
with open(source_path, "w") as f:
|
||||||
f.write(source_all_code)
|
f.write(source_all_code)
|
||||||
@@ -495,19 +500,18 @@ def generate_launch_gemm_cus(
|
|||||||
|
|
||||||
|
|
||||||
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
|
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
|
||||||
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
|
def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
|
||||||
fuse_gemm_configs: tuple, sm: str):
|
|
||||||
"""
|
"""
|
||||||
generate_dispatch_gemm_cu
|
generate_dispatch_gemm_cu
|
||||||
"""
|
"""
|
||||||
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
|
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
|
||||||
|
|
||||||
single_config = fuse_gemm_configs[0]
|
single_config = fuse_gemm_configs[0]
|
||||||
hasbiases: (str) = single_config[0]
|
hasbiases: str = single_config[0]
|
||||||
tiles: (str) = single_config[2]
|
tiles: str = single_config[2]
|
||||||
KernelSchedule: (str) = single_config[3]
|
KernelSchedule: str = single_config[3]
|
||||||
EpilogueSchedule: (str) = single_config[4]
|
EpilogueSchedule: str = single_config[4]
|
||||||
TileSchedule: (str) = single_config[5]
|
TileSchedule: str = single_config[5]
|
||||||
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
|
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
|
||||||
type_id = 0
|
type_id = 0
|
||||||
for input_type in inputs_type:
|
for input_type in inputs_type:
|
||||||
@@ -530,9 +534,12 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
|
|||||||
for kernel_schedule in KernelSchedule:
|
for kernel_schedule in KernelSchedule:
|
||||||
for epilogue_schedule in EpilogueSchedule:
|
for epilogue_schedule in EpilogueSchedule:
|
||||||
for tile_schedule in TileSchedule:
|
for tile_schedule in TileSchedule:
|
||||||
if not check_config_valid(tile_shape, kernel_schedule,
|
if not check_config_valid(
|
||||||
epilogue_schedule,
|
tile_shape,
|
||||||
tile_schedule):
|
kernel_schedule,
|
||||||
|
epilogue_schedule,
|
||||||
|
tile_schedule,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
value_dict = {
|
value_dict = {
|
||||||
"TileShape": tile_shape[0],
|
"TileShape": tile_shape[0],
|
||||||
@@ -554,18 +561,18 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
 for epilogue_schedule in EpilogueSchedule:
 gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
 for tile_schedule in TileSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule,
-tile_schedule):
+if not check_config_valid(
+tile_shape,
+kernel_schedule,
+epilogue_schedule,
+tile_schedule,
+):
 continue
 gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
 value_dict = {
-"sm":
-sm[-2:],
-"tile_id":
-str(tile_id),
-"gemm_config":
-gemm_config_str.replace("<", "").replace(">", ""),
+"sm": sm[-2:],
+"tile_id": str(tile_id),
+"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
 }
 all_code += SubstituteTemplate(code_part5, value_dict)
 tile_id += 1
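For orientation, the dispatch generator walks the cross-product of tile shapes and schedules, skips invalid combinations, and stamps each surviving config into a template with a running tile_id. A simplified, self-contained sketch of that loop (the is_valid and render callables stand in for check_config_valid and the repo's code_part5 template):

    def build_dispatch_cases(tiles, kernel_scheds, epilogue_scheds, tile_scheds, is_valid, render):
        """Enumerate valid configs and give each a dense tile_id for the dispatch switch."""
        cases, tile_id = [], 0
        for tile_shape in tiles:
            for ks in kernel_scheds:
                for es in epilogue_scheds:
                    for ts in tile_scheds:
                        if not is_valid(tile_shape, ks, es, ts):
                            continue  # skip combinations the kernel cannot build
                        name = f"{tile_shape[0]}_{ks}_{es}_{ts}".replace("<", "").replace(">", "")
                        cases.append(render(tile_id, name))
                        tile_id += 1  # each valid config gets a unique dispatch case
        return cases

    # Example: one tile, one schedule each, everything valid.
    print(build_dispatch_cases(
        [("<_64, _64, _128>",)], ["K"], ["E"], ["T"],
        lambda *a: True,
        lambda i, n: f"case {i}: // {n}",
    ))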
@@ -610,12 +617,17 @@ if __name__ == "__main__":
 f.close()
 # Compile parallelization
 generate_launch_gemm_cus(
-"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
-outputs_type, fuse_gemm_configs, sm_dict[sm])
+"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
+inputs_type,
+outputs_type,
+fuse_gemm_configs,
+sm_dict[sm],
+)

 # hard code for act_tag
-file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
-f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu")
+file_name = (
+f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu"
+)
 all_code = generate_dispatch_gemm_cu(
 inputs_type,
 outputs_type,
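The __main__ blocks in these generators all follow the same two-step pattern: emit one small launch .cu per gemm config into the autogen directory (the "# Compile parallelization" comment), then write a single dispatch file per architecture. A rough, runnable sketch of that flow with stand-in generators (the real ones are generate_launch_gemm_cus and generate_dispatch_gemm_cu above):

    import os

    def generate_dispatch_cu(sm: str) -> str:
        # Stand-in for generate_dispatch_gemm_cu: returns the full dispatch source.
        return f"// dispatch for sm{sm}\n"

    def generate_launch_cus(out_dir: str, sm: str) -> None:
        # Stand-in for generate_launch_gemm_cus: one small .cu per config.
        os.makedirs(out_dir, exist_ok=True)
        for config in ("block64x64x64_stage3", "block128x128x64_stage4"):
            with open(os.path.join(out_dir, f"launch_gemm_kernel_sm{sm}_{config}.cu"), "w") as f:
                f.write(f"// {config}\n")

    if __name__ == "__main__":
        out_dir = "autogen"  # the repo uses gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen
        for sm in ("89", "90"):
            generate_launch_cus(out_dir, sm)  # many files -> parallel compilation
            with open(os.path.join(out_dir, f"fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"), "w") as f:
                f.write(generate_dispatch_cu(sm))  # one dispatch file per arch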
@@ -24,27 +24,28 @@ def get_candidate_tiles():
 """
 base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]

-base_configs.extend([
-("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
-("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
-("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
-("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
-("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
-("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
-("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
-])
+base_configs.extend(
+[
+("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
+("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
+("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
+("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
+("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
+("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
+("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
+]
+)

 return base_configs


-def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages,
-max_stages):
+def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
 """
 get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel.
 """
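Each candidate tile above is a triple of CUTLASS-style "<M, N, K>" strings, read by the generators as threadblock shape, warp shape, and tensor-core instruction shape; the strings are later split apart to build kernel names. A short sketch of how one triple decomposes (values taken from the list above; the interpretation follows the blocks/warps/mmas names used in the generator code):

    tile = ("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")
    blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
    # blocks == ["64", "64", "64"]  -> threadblock tile M x N x K
    # warps  == ["32", "32", "64"]  -> warp tile
    # mmas   == ["16", "8", "32"]   -> instruction shape
    name = f"block{'x'.join(blocks)}_warp{'x'.join(warps)}_mma{'x'.join(mmas)}"
    print(name)  # block64x64x64_warp32x32x64_mma16x8x32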
@@ -299,8 +300,7 @@ def check_min_split_k(value):
 """
 ivalue = int(value)
 if ivalue > 1:
-raise argparse.ArgumentTypeError(
-"Dual gemm split_k mode is not support.")
+raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.")
 return ivalue


@@ -310,8 +310,7 @@ def check_max_split_k(value):
 """
 ivalue = int(value)
 if ivalue > 1:
-raise argparse.ArgumentTypeError(
-"Dual gemm split_k mode is not support..")
+raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support..")
 return ivalue

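check_min_split_k and check_max_split_k are argparse type-validators: argparse calls them on the raw string, and an ArgumentTypeError becomes a clean CLI error instead of a traceback. A minimal sketch of how such a validator is wired up (the flag name here is illustrative):

    import argparse

    def check_min_split_k(value):
        ivalue = int(value)
        if ivalue > 1:
            # Dual gemm has no split-k path, so anything above 1 is rejected.
            raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.")
        return ivalue

    parser = argparse.ArgumentParser()
    parser.add_argument("--min_split_k", type=check_min_split_k, default=1)
    args = parser.parse_args(["--min_split_k", "1"])  # "--min_split_k 2" would exit with an error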
@@ -320,8 +319,7 @@ def parse_args():
 parse_args
 """
 parser = argparse.ArgumentParser(
-description=
-"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
+description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
 )
 parser.add_argument(
 "--cuda_arch",
@@ -421,8 +419,7 @@ def generate_dual_gemm_source_cu(
 "hasbias": hasbias,
 "SM": sm,
 }
-all_code += SubstituteTemplate(
-GemmSplitKDeclare, value_dict)
+all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)

 all_code += CommonTail
 return all_code
@@ -449,12 +446,12 @@ def generate_launch_dual_gemm_cus(
 head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
 head_all_code = LaunchGemmHead
 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
-gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
-f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
-f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
+gemm_config = (
+f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
+f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
+f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
+)
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
 value_dict = {
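The reflowed gemm_config assignment relies on Python's implicit concatenation of adjacent string literals inside parentheses: the three f-strings form one name with no "+" needed. A tiny sketch of that behavior with sample values:

    blocks, warps, mmas = ["64", "128", "64"], ["32", "64", "64"], ["16", "8", "32"]
    gemm_config = (
        f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
        f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
        f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
    )
    assert gemm_config == "block64x128x64_warp32x64x64_mma16x8x32"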
@@ -467,12 +464,12 @@ def generate_launch_dual_gemm_cus(
 f.close()

 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
-gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
-f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
-f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
+gemm_config = (
+f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
+f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
+f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
+)
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
 value_dict = {
@@ -498,16 +495,14 @@ def generate_launch_dual_gemm_cus(
 "num_stages": str(stage),
 "SM": sm,
 }
-source_all_code += SubstituteTemplate(
-LaunchGemmPart1, value_dict)
+source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
 # split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
 type_id += 1
 source_all_code += LaunchGemmPart2
 # source_all_code += split_k_code
 # source_all_code += LaunchGemmPart4
 code_map[gemm_config_str] = source_all_code
-source_path = os.path.join(
-generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
+source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
 with open(source_path, "w") as f:
 f.write(source_all_code)
 f.close()
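Each gemm_config_str becomes its own .cu translation unit, which is what the "# Compile parallelization" comments refer to: many small files let the build compile configs concurrently. A condensed, runnable sketch of the write-out step (code_map maps config name to generated source, as in the loop above):

    import os

    def write_launch_sources(generate_dir: str, code_map: dict) -> None:
        os.makedirs(generate_dir, exist_ok=True)
        for gemm_config_str, source_all_code in code_map.items():
            source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
            with open(source_path, "w") as f:  # the with-block already closes the file
                f.write(source_all_code)

    write_launch_sources("autogen", {"block64x64x64_stage3": "// generated code\n"})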
@@ -566,12 +561,12 @@ def generate_dispatch_dual_gemm_cu(

 tile_id = 0
 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
-gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
-f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
-f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
+gemm_config = (
+f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
+f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
+f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
+)
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
 value_dict = {
@@ -580,10 +575,12 @@ def generate_dispatch_dual_gemm_cu(
 }
 all_code += SubstituteTemplate(code_part5, value_dict)
 tile_id += 1
-value_dict.update({
-"min_split_k": str(min_split_k),
-"max_split_k": str(max_split_k),
-})
+value_dict.update(
+{
+"min_split_k": str(min_split_k),
+"max_split_k": str(max_split_k),
+}
+)
 all_code += SubstituteTemplate(code_part6, value_dict)
 return all_code

@@ -602,8 +599,7 @@ if __name__ == "__main__":

 for sm in archs:
 if sm == "89":
-fuse_gemm_configs = get_dual_gemm_candidate_configs(
-sm, min_split_k, max_split_k, min_stages, max_stages)
+fuse_gemm_configs = get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
 for fuse_gemm_config in fuse_gemm_configs:
 file_name = (
 f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"

@@ -19,8 +19,7 @@ import re


 def get_candidate_tiles():
-"""
-"""
+""" """
 cta_shape = [
 ("<_64, _16, _128>"),
 ("<_64, _32, _128>"),
@@ -45,8 +44,7 @@ def get_candidate_tiles():


 def get_dual_gemm_candidate_configs(sm):
-"""
-"""
+""" """
 tiles = get_candidate_tiles()
 candidate_configs = list()

@@ -64,35 +62,27 @@ def get_dual_gemm_candidate_configs(sm):
 ("swiglu", "SiLu"),
 ("geglu", "GELU"),
 ]:
-candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
-EpilogueSchedule)])
+candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])

 return candidate_configs


 def get_shape_str(tile_shape):
-"""
-"""
-blocks, clusters = [
-s.replace(" ", "").strip("<>").split(",") for s in tile_shape
-]
+""" """
+blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
 blocks = [elem.strip("_") for elem in blocks]
 clusters = [elem.strip("_") for elem in clusters]
 return blocks, clusters


 def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
-"""
-"""
+""" """
 blocks, clusters = get_shape_str(tile_shape)
-if int(
-blocks[0]
-) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
+if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
 return False
 if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
 return False
-if tile_shape[
-0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
+if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
 return False
 return True

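check_config_valid encodes the sm90 scheduling constraints visible in this hunk: cooperative mainloop schedules need a threadblock M of at least 128 and a cooperative epilogue, and the 128x128x128 tile is excluded for the pingpong schedule. A self-contained sketch of the same predicate (rules copied from the lines above):

    def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
        blocks = [b.strip("_") for b in tile_shape[0].replace(" ", "").strip("<>").split(",")]
        # Cooperative schedules need at least a 128-wide threadblock M dimension.
        if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
            return False
        # A cooperative mainloop requires a cooperative epilogue.
        if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
            return False
        # This tile is excluded for the pingpong schedule.
        if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
            return False
        return True

    assert check_config_valid(("<_128, _128, _128>", "<_2, _1, _1>"),
                              "KernelTmaWarpSpecializedCooperativeFP8FastAccum",
                              "TmaWarpSpecializedCooperative")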
@@ -302,8 +292,7 @@ bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {


 def SubstituteTemplate(template, values):
-"""
-"""
+""" """
 text = template
 changed = True
 while changed:
@@ -318,10 +307,8 @@ def SubstituteTemplate(template, values):


 def parse_args():
-"""
-"""
-parser = argparse.ArgumentParser(
-description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
+""" """
+parser = argparse.ArgumentParser(description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
 parser.add_argument(
 "--cuda_arch",
 type=str,
@@ -336,17 +323,16 @@ def parse_args():

 # generate source .cu
 def generate_dual_gemm_source_cu(
-inputs_type: (str),
-biases_type: (str),
-hasbiases: (str),
-act_tag: (str),
-tiles: (str),
-KernelSchedule: (str),
-EpilogueSchedule: (str),
+inputs_type: str,
+biases_type: str,
+hasbiases: str,
+act_tag: str,
+tiles: str,
+KernelSchedule: str,
+EpilogueSchedule: str,
 sm: str,
 ):
-"""
-"""
+""" """
 all_code = CommonHead
 for input_type in inputs_type:
 for bias_type in biases_type:
@@ -354,9 +340,7 @@ def generate_dual_gemm_source_cu(
 for tile_config in tiles:
 for kernel_schedule in KernelSchedule:
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_config,
-kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
 continue
 value_dict = {
 "input_type": input_type,
@@ -370,28 +354,29 @@ def generate_dual_gemm_source_cu(
 "SM": sm,
 "sm": sm[-2:],
 }
-all_code += SubstituteTemplate(
-GemmDeclare, value_dict)
+all_code += SubstituteTemplate(GemmDeclare, value_dict)

 return all_code


 # generate gemm launch .cu
 def generate_launch_dual_gemm_cus(
-generate_dir: (str), inputs_type: (str), biases_type: (str),
-fuse_gemm_configs: tuple, sm: str):
-"""
-"""
+generate_dir: str,
+inputs_type: str,
+biases_type: str,
+fuse_gemm_configs: tuple,
+sm: str,
+):
+""" """
 act_tags = [single_config[1] for single_config in fuse_gemm_configs]

 single_config = fuse_gemm_configs[0]
-hasbiases: (str) = single_config[0]
-tiles: (str) = single_config[2]
-KernelSchedule: (str) = single_config[3]
-EpilogueSchedule: (str) = single_config[4]
+hasbiases: str = single_config[0]
+tiles: str = single_config[2]
+KernelSchedule: str = single_config[3]
+EpilogueSchedule: str = single_config[4]
 code_map = {}
-head_path = os.path.join(generate_dir,
-f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
+head_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
 head_all_code = LaunchGemmHead
 for tile_config in tiles:
 blocks, clusters = get_shape_str(tile_config)
@@ -401,16 +386,14 @@ def generate_launch_dual_gemm_cus(
 for kernel_schedule in KernelSchedule:
 gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_config, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
 continue
 gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
 value_dict = {
 "sm": sm[-2:],
 "gemm_config": gemm_config_str,
 }
-head_all_code += SubstituteTemplate(LaunchGemmDeclare,
-value_dict)
+head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
 os.makedirs(generate_dir, exist_ok=True)
 with open(head_path, "w") as f:
 f.write(head_all_code)
@@ -422,16 +405,14 @@ def generate_launch_dual_gemm_cus(
 for kernel_schedule in KernelSchedule:
 gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
 continue
 gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
 value_dict = {
 "sm": sm[-2:],
 "gemm_config": gemm_config_str,
 }
-source_all_code = SubstituteTemplate(LaunchGemmPart0,
-value_dict)
+source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
 type_id = 0
 for input_type in inputs_type:
 for bias_type in biases_type:
@@ -450,14 +431,13 @@ def generate_launch_dual_gemm_cus(
 "SM": sm,
 "sm": sm[-2:],
 }
-source_all_code += SubstituteTemplate(
-LaunchGemmPart1, value_dict)
+source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
 type_id += 1
 source_all_code += LaunchGemmPart2
 code_map[gemm_config_str] = source_all_code
 source_path = os.path.join(
 generate_dir,
-f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
+f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
 )
 with open(source_path, "w") as f:
 f.write(source_all_code)
@@ -467,16 +447,14 @@ def generate_launch_dual_gemm_cus(


 # generate fp8_fp8_gemm_scale_bias_act.cu
-def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
-fuse_gemm_configs: tuple, sm: str):
-"""
-"""
+def generate_dispatch_dual_gemm_cu(inputs_type: str, biases_type: str, fuse_gemm_configs: tuple, sm: str):
+""" """
 act_tags = [single_config[1] for single_config in fuse_gemm_configs]
 single_config = fuse_gemm_configs[0]
-hasbiases: (str) = single_config[0]
-tiles: (str) = single_config[2]
-KernelSchedule: (str) = single_config[3]
-EpilogueSchedule: (str) = single_config[4]
+hasbiases: str = single_config[0]
+tiles: str = single_config[2]
+KernelSchedule: str = single_config[3]
+EpilogueSchedule: str = single_config[4]

 all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
 type_id = 0
@@ -500,8 +478,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
 for tile_shape in tiles:
 for kernel_schedule in KernelSchedule:
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
 continue
 value_dict = {
 "TileShape": tile_shape[0],
@@ -520,8 +497,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
 for kernel_schedule in KernelSchedule:
 gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
 continue
 gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
 value_dict = {
@@ -570,12 +546,15 @@ if __name__ == "__main__":
 f.close()
 # Compile parallelization
 generate_launch_dual_gemm_cus(
-"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
-biases_type, fuse_gemm_configs, sm_dict[sm])
+"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
+inputs_type,
+biases_type,
+fuse_gemm_configs,
+sm_dict[sm],
+)
 # hard code for act_tag
 file_name = (
-f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
-f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
+f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
 )
 all_code = generate_dispatch_dual_gemm_cu(
 inputs_type,
@@ -31,25 +31,26 @@ def get_candidate_tiles():
 """
 base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]

-base_configs.extend([
-("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
-("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
-("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
-("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
-("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
-])
+base_configs.extend(
+[
+("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
+("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
+("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
+("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
+("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
+]
+)

 return base_configs


-def get_candidate_configs(sm, min_split_k, max_split_k, min_stages,
-max_stages):
+def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
 """
 Get the list of candidate gemm operator configurations.

@@ -353,8 +354,7 @@ def parse_args():
 Parse the command-line arguments.
 """
 parser = argparse.ArgumentParser(
-description=
-"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
+description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
 )
 parser.add_argument(
 "--cuda_arch",
@@ -448,8 +448,7 @@ def generate_source_cu(
 "hasbias": hasbias,
 "SM": sm,
 }
-all_code += SubstituteTemplate(GemmSplitKDeclare,
-value_dict)
+all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)

 all_code += CommonTail
 return all_code
@@ -473,9 +472,7 @@ def generate_launch_gemm_cus(
 head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
 head_all_code = LaunchGemmHead
 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
 gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
@@ -489,9 +486,7 @@ def generate_launch_gemm_cus(
 f.close()

 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
 gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
@@ -517,17 +512,14 @@ def generate_launch_gemm_cus(
 "num_stages": str(stage),
 "SM": sm,
 }
-source_all_code += SubstituteTemplate(
-LaunchGemmPart1, value_dict)
-split_k_code += SubstituteTemplate(
-LaunchGemmPart3, value_dict)
+source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
+split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
 type_id += 1
 source_all_code += LaunchGemmPart2
 source_all_code += split_k_code
 source_all_code += LaunchGemmPart4
 code_map[gemm_config_str] = source_all_code
-source_path = os.path.join(
-generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
+source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
 with open(source_path, "w") as f:
 f.write(source_all_code)
 f.close()
@@ -581,9 +573,7 @@ def generate_dispatch_gemm_cu(
 all_code += code_part4
 tile_id = 0
 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
 gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
@@ -593,10 +583,12 @@ def generate_dispatch_gemm_cu(
 }
 all_code += SubstituteTemplate(code_part5, value_dict)
 tile_id += 1
-value_dict.update({
-"min_split_k": str(min_split_k),
-"max_split_k": str(max_split_k),
-})
+value_dict.update(
+{
+"min_split_k": str(min_split_k),
+"max_split_k": str(max_split_k),
+}
+)
 all_code += SubstituteTemplate(code_part6, value_dict)
 return all_code

@@ -614,9 +606,7 @@ if __name__ == "__main__":

 for sm in archs:
 if sm == "89":
-fuse_gemm_configs = get_candidate_configs(sm, min_split_k,
-max_split_k, min_stages,
-max_stages)
+fuse_gemm_configs = get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
 for fuse_gemm_config in fuse_gemm_configs:
 file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
 all_code = generate_source_cu(
@@ -654,9 +644,7 @@ if __name__ == "__main__":

 # hard code for act_tag

-file_name = (
-"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
-)
+file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
 all_code = generate_dispatch_gemm_cu(
 inputs_type,
 outputs_type,

@@ -20,44 +20,44 @@ import re


 def get_candidate_tiles():
-"""
-"""
+""" """
 base_configs = [
 ("<_64, _64, _128>", "<_1, _8, _1>"),
 ("<_64, _128, _128>", "<_2, _1, _1>"),
 ("<_128, _128, _128>", "<_2, _1, _1>"),
 ]
-base_configs.extend([
-("<_64, _64, _128>", "<_1, _1, _1>"),
-("<_64, _64, _128>", "<_1, _2, _1>"),
-("<_64, _64, _128>", "<_2, _1, _1>"),
-("<_64, _64, _64>", "<_1, _1, _1>"),
-("<_64, _64, _64>", "<_1, _2, _1>"),
-("<_64, _64, _64>", "<_2, _1, _1>"),
-("<_64, _128, _128>", "<_1, _2, _1>"),
-("<_64, _128, _128>", "<_1, _1, _1>"),
-("<_128, _128, _64>", "<_2, _1, _1>"),
-("<_256, _128, _128>", "<_1, _2, _1>"),
-("<_256, _128, _128>", "<_1, _1, _1>"),
-# The following configurations are rarely selected in Qwen2-7B-model.
-# ("<_256, _128, _128>", "<_4, _1, _1>"),
-# ("<_256, _128, _128>", "<_1, _4, _1>"),
-# ("<_256, _128, _128>", "<_2, _4, _1>"),
-# ("<_128, _128, _256>", "<_1, _2, _1>"),
-# ("<_128, _128, _128>", "<_4, _1, _1>"),
-# ("<_128, _128, _128>", "<_2, _4, _1>"),
-# ("<_128, _128, _128>", "<_1, _2, _1>"),
-# ("<_128, _128, _128>", "<_1, _1, _1>"),
-# ("<_128, _128, _128>", "<_1, _4, _1>"),
-# ("<_128, _128, _64>", "<_2, _2, _1>"),
-])
+base_configs.extend(
+[
+("<_64, _64, _128>", "<_1, _1, _1>"),
+("<_64, _64, _128>", "<_1, _2, _1>"),
+("<_64, _64, _128>", "<_2, _1, _1>"),
+("<_64, _64, _64>", "<_1, _1, _1>"),
+("<_64, _64, _64>", "<_1, _2, _1>"),
+("<_64, _64, _64>", "<_2, _1, _1>"),
+("<_64, _128, _128>", "<_1, _2, _1>"),
+("<_64, _128, _128>", "<_1, _1, _1>"),
+("<_128, _128, _64>", "<_2, _1, _1>"),
+("<_256, _128, _128>", "<_1, _2, _1>"),
+("<_256, _128, _128>", "<_1, _1, _1>"),
+# The following configurations are rarely selected in Qwen2-7B-model.
+# ("<_256, _128, _128>", "<_4, _1, _1>"),
+# ("<_256, _128, _128>", "<_1, _4, _1>"),
+# ("<_256, _128, _128>", "<_2, _4, _1>"),
+# ("<_128, _128, _256>", "<_1, _2, _1>"),
+# ("<_128, _128, _128>", "<_4, _1, _1>"),
+# ("<_128, _128, _128>", "<_2, _4, _1>"),
+# ("<_128, _128, _128>", "<_1, _2, _1>"),
+# ("<_128, _128, _128>", "<_1, _1, _1>"),
+# ("<_128, _128, _128>", "<_1, _4, _1>"),
+# ("<_128, _128, _64>", "<_2, _2, _1>"),
+]
+)

 return base_configs


 def get_candidate_configs(sm):
-"""
-"""
+""" """
 tiles = get_candidate_tiles()
 candidate_configs = list()

@@ -73,36 +73,31 @@ def get_candidate_configs(sm):
 ("relu", "ReLu"),
 ("gelu", "GELU"),
 ]:
-candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
-EpilogueSchedule)])
+candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])

 return candidate_configs


 def get_shape_str(tile_shape):
-"""
-"""
-blocks, clusters = [
-s.replace(" ", "").strip("<>").split(",") for s in tile_shape
-]
+""" """
+blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
 blocks = [elem.strip("_") for elem in blocks]
 clusters = [elem.strip("_") for elem in clusters]
 return blocks, clusters


 def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
-"""
-"""
+""" """
 blocks, clusters = get_shape_str(tile_shape)
-if int(
-blocks[0]
-) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
+if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
 return False
 if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
 return False
-if (tile_shape[0] == "<_256, _128, _128>"
-and "Cooperative" not in kernel_schedule
-and "Cooperative" not in epilogue_schedule):
+if (
+tile_shape[0] == "<_256, _128, _128>"
+and "Cooperative" not in kernel_schedule
+and "Cooperative" not in epilogue_schedule
+):
 return False
 return True

@@ -321,16 +316,12 @@ bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {


 def SubstituteTemplate(template, values_base):
-"""
-"""
+""" """
 values = copy.deepcopy(values_base)
-if values.get("KernelSchedule"
-) is not None and "Auto" in values["KernelSchedule"]:
+if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
 values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
-if values.get("EpilogueSchedule"
-) is not None and "Auto" in values["EpilogueSchedule"]:
-values[
-"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
+if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
+values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
 text = template
 changed = True
 while changed:
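SubstituteTemplate repeatedly replaces placeholders until a full pass changes nothing, so a value substituted in one pass can itself contain placeholders resolved in the next. The exact placeholder syntax is not shown in these hunks, so this self-contained sketch assumes ${name} markers:

    import copy
    import re

    def substitute_template(template: str, values_base: dict) -> str:
        values = copy.deepcopy(values_base)
        text = template
        changed = True
        while changed:  # keep going until a pass makes no replacement
            changed = False
            for key, value in values.items():
                new_text = re.sub(r"\$\{%s\}" % key, value, text)
                if new_text != text:
                    changed = True
                text = new_text
        return text

    print(substitute_template("kernel_${gemm_config}", {"gemm_config": "block64x64x64_stage3"}))
    # -> kernel_block64x64x64_stage3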
@@ -345,10 +336,8 @@ def SubstituteTemplate(template, values_base):


 def parse_args():
-"""
-"""
-parser = argparse.ArgumentParser(
-description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
+""" """
+parser = argparse.ArgumentParser(description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
 parser.add_argument(
 "--cuda_arch",
 type=str,
@@ -363,17 +352,16 @@ def parse_args():

 # generate source .cu
 def generate_source_cu(
-inputs_type: (str),
-outputs_type: (str),
-hasbiases: (str),
-act_tag: (str),
-tiles: (str),
-KernelSchedule: (str),
-EpilogueSchedule: (str),
+inputs_type: str,
+outputs_type: str,
+hasbiases: str,
+act_tag: str,
+tiles: str,
+KernelSchedule: str,
+EpilogueSchedule: str,
 sm: str,
 ):
-"""
-"""
+""" """
 all_code = CommonHead

 for input_type in inputs_type:
@@ -382,9 +370,7 @@ def generate_source_cu(
 for tile_config in tiles:
 for kernel_schedule in KernelSchedule:
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_config,
-kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
 continue
 value_dict = {
 "input_type": input_type,
@@ -398,25 +384,27 @@ def generate_source_cu(
 "SM": sm,
 "sm": sm[-2:],
 }
-all_code += SubstituteTemplate(
-GemmDeclare, value_dict)
+all_code += SubstituteTemplate(GemmDeclare, value_dict)

 return all_code


 # generate gemm launch .cu
 def generate_launch_gemm_cus(
-generate_dir: (str), inputs_type: (str), outputs_type: (str),
-fuse_gemm_configs: tuple, sm: str):
-"""
-"""
+generate_dir: str,
+inputs_type: str,
+outputs_type: str,
+fuse_gemm_configs: tuple,
+sm: str,
+):
+""" """
 act_tags = [single_config[1] for single_config in fuse_gemm_configs]

 single_config = fuse_gemm_configs[0]
-hasbiases: (str) = single_config[0]
-tiles: (str) = single_config[2]
-KernelSchedule: (str) = single_config[3]
-EpilogueSchedule: (str) = single_config[4]
+hasbiases: str = single_config[0]
+tiles: str = single_config[2]
+KernelSchedule: str = single_config[3]
+EpilogueSchedule: str = single_config[4]
 code_map = {}
 head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
 head_all_code = LaunchGemmHead
@@ -426,16 +414,14 @@ def generate_launch_gemm_cus(
 for kernel_schedule in KernelSchedule:
 gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_config, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
 continue
 gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
 value_dict = {
 "sm": sm[-2:],
 "gemm_config": gemm_config_str,
 }
-head_all_code += SubstituteTemplate(LaunchGemmDeclare,
-value_dict)
+head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
 os.makedirs(generate_dir, exist_ok=True)
 with open(head_path, "w") as f:
 f.write(head_all_code)
@@ -447,16 +433,14 @@ def generate_launch_gemm_cus(
 for kernel_schedule in KernelSchedule:
 gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
 continue
 gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
 value_dict = {
 "sm": sm[-2:],
 "gemm_config": gemm_config_str,
 }
-source_all_code = SubstituteTemplate(LaunchGemmPart0,
-value_dict)
+source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
 type_id = 0
 for input_type in inputs_type:
 for output_type in outputs_type:
@@ -475,14 +459,14 @@ def generate_launch_gemm_cus(
 "SM": sm,
 "sm": sm[-2:],
 }
-source_all_code += SubstituteTemplate(
-LaunchGemmPart1, value_dict)
+source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
 type_id += 1
 source_all_code += LaunchGemmPart2
 code_map[gemm_config_str] = source_all_code
 source_path = os.path.join(
 generate_dir,
-f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu")
+f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
+)
 with open(source_path, "w") as f:
 f.write(source_all_code)
 f.close()
@@ -491,17 +475,15 @@ def generate_launch_gemm_cus(


 # generate fp8_fp8_gemm_scale_bias_act_sm90.cu
-def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
-fuse_gemm_configs: tuple, sm: str):
-"""
-"""
+def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
+""" """
 act_tags = [single_config[1] for single_config in fuse_gemm_configs]

 single_config = fuse_gemm_configs[0]
-hasbiases: (str) = single_config[0]
-tiles: (str) = single_config[2]
-KernelSchedule: (str) = single_config[3]
-EpilogueSchedule: (str) = single_config[4]
+hasbiases: str = single_config[0]
+tiles: str = single_config[2]
+KernelSchedule: str = single_config[3]
+EpilogueSchedule: str = single_config[4]

 all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
 type_id = 0
@@ -524,8 +506,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
 for tile_shape in tiles:
 for kernel_schedule in KernelSchedule:
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
 continue
 value_dict = {
 "TileShape": tile_shape[0],
@@ -544,8 +525,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
 for kernel_schedule in KernelSchedule:
 gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
 for epilogue_schedule in EpilogueSchedule:
-if not check_config_valid(tile_shape, kernel_schedule,
-epilogue_schedule):
+if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
 continue
 gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
 value_dict = {
@@ -576,7 +556,8 @@ if __name__ == "__main__":
 for fuse_gemm_config in fuse_gemm_configs:
 file_name = (
 f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
-f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu")
+f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
+)
 all_code = generate_source_cu(
 inputs_type,
 outputs_type,
@@ -594,8 +575,12 @@ if __name__ == "__main__":
 f.close()
 # Compile parallelization
 generate_launch_gemm_cus(
-"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
-outputs_type, fuse_gemm_configs, sm_dict[sm])
+"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
+inputs_type,
+outputs_type,
+fuse_gemm_configs,
+sm_dict[sm],
+)

 # hard code for act_tag
 file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"
@@ -30,22 +30,24 @@ def get_candidate_tiles():

 """
 base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
-base_configs.extend([
-("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
-("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
-("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
-("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
-("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
-("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
-("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
-("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
-("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
-])
+base_configs.extend(
+[
+("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
+("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
+("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
+("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
+("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
+("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
+("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
+("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
+("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
+]
+)

 return base_configs

@@ -278,8 +280,7 @@ def parse_args():
 Parse the command-line arguments.
 """
 parser = argparse.ArgumentParser(
-description=
-"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
+description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
 )
 parser.add_argument(
 "--cuda_arch",
@@ -370,13 +371,10 @@ def generate_launch_gemm_cus(
 - dict (code_map) - a dict mapping each gemm config to its generated source code, in the form {"gemm_config": source_code}.
 """
 code_map = {}
-head_path = os.path.join(generate_dir,
-"launch_visitor_gemm_fused_kernel.h")
+head_path = os.path.join(generate_dir, "launch_visitor_gemm_fused_kernel.h")
 head_all_code = LaunchGemmHead
 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
 gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
@@ -390,9 +388,7 @@ def generate_launch_gemm_cus(
 f.close()

 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
 gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
@@ -415,14 +411,14 @@ def generate_launch_gemm_cus(
 "num_stages": str(stage),
 "SM": sm,
 }
-source_all_code += SubstituteTemplate(
-LaunchGemmPart1, value_dict)
+source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
 type_id += 1
 source_all_code += LaunchGemmPart2
 code_map[gemm_config_str] = source_all_code
 source_path = os.path.join(
 generate_dir,
-f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu")
+f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu",
+)
 with open(source_path, "w") as f:
 f.write(source_all_code)
 f.close()
@@ -485,9 +481,7 @@ def generate_dispatch_gemm_cu(
 all_code += code_part4
 tile_id = 0
 for tile in tiles:
-blocks, warps, mmas = [
-s.replace(" ", "").strip("<>").split(",") for s in tile
-]
+blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
 gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
 for stage in stages:
 gemm_config_str = gemm_config + f"_stage{stage}"
@@ -512,10 +506,11 @@ if __name__ == "__main__":

    for sm in archs:
        if sm == "89":
-            fuse_gemm_configs = get_candidate_configs(sm, min_stages,
-                                                      max_stages)
+            fuse_gemm_configs = get_candidate_configs(sm, min_stages, max_stages)
            for fuse_gemm_config in fuse_gemm_configs:
-                file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
+                file_name = (
+                    f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
+                )
                all_code = generate_source_cu(
                    inputs_type,
                    outputs_type,
@@ -544,9 +539,7 @@ if __name__ == "__main__":
            sm_dict[sm],
        )

-    file_name = (
-        "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
-    )
+    file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
    all_code = generate_dispatch_gemm_cu(
        inputs_type,
        outputs_type,
@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
  vsl.kv_lod_vp = {
      const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
      enc_batch + 1, nullptr};

  baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
      nullptr,
      0,
@@ -30,8 +30,7 @@ current_file = Path(__file__).resolve()
 base_dir = current_file.parent


-def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
-                 XDNN_LIB_DIR):
+def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
    """
    build xpu plugin
    """
@@ -49,7 +48,10 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,

    # Remove the specified directories
-    dirs_to_remove = [
-        "dist", "fastdeploy_ops.egg-info", "build", "plugin/build"
-    ]
+    dirs_to_remove = [
+        "dist",
+        "fastdeploy_ops.egg-info",
+        "build",
+        "plugin/build",
+    ]
    for dir_name in dirs_to_remove:
        if os.path.exists(dir_name):
@@ -58,8 +60,7 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,

    # Run the build script in the plugin directory
    plugin_dir = "plugin"
-    build_script = os.path.join(current_working_directory, plugin_dir,
-                                "build.sh")
+    build_script = os.path.join(current_working_directory, plugin_dir, "build.sh")

    print("build_script: ", build_script)

@@ -74,14 +75,16 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
    # Execute the build script
    try:
        print("Running build script...")
-        subprocess.run([build_script],
-                       check=True,
-                       cwd=os.path.join(current_working_directory, plugin_dir))
+        subprocess.run(
+            [build_script],
+            check=True,
+            cwd=os.path.join(current_working_directory, plugin_dir),
+        )
        print("Build completed successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Build failed with error: {e}")
    except Exception as e:
-        print(f"Unexpected error: {str(e)}")
+        print(f"Unexpected error: {e!s}")


 def xpu_setup_ops():
@@ -124,17 +127,14 @@ def xpu_setup_ops():
    XVLLM_PATH = os.getenv("XVLLM_PATH")
    assert XVLLM_PATH is not None, "XVLLM_PATH is not set."
    XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include")
-    XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so",
-                                         "libapiinfer.so")
+    XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so", "libapiinfer.so")
    XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so")
    XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include")
-    XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so",
-                                     "libxft_blocks.so")
+    XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so", "libxft_blocks.so")
    XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so")

    # build plugin
-    build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH,
-                 XDNN_LIB_DIR)
+    build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)

    ops = [
        # custom ops
@@ -152,7 +152,6 @@ def xpu_setup_ops():
        "./ops/block_attn.cc",
        "./ops/moe_layer.cc",
        "./ops/weight_quantize_xpu.cc",
-
        # device manage ops
        "./ops/device/get_context_gm_max_mem_demand.cc",
        "./ops/device/get_free_global_memory.cc",
@@ -29,7 +29,7 @@ for i in range(bs):
    ids_len = seq_lens[i, 0]
    input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64")

-x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
+(x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k,) = get_padding_offset(
    paddle.to_tensor(input_ids),
    paddle.to_tensor(cum_offset),
    paddle.to_tensor(token_num),
@@ -46,19 +46,14 @@ print("padding_offset:\n", padding_offset)
 print("cu_seqlens_q:\n", cu_seqlens_q)
 print("cu_seqlens_k:\n", cu_seqlens_k)

-ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6],
-                                "int64")
+ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
 ref_cum_offsets_out = np.array([0, 6, 13], "int32")
-ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13],
-                              "int32")
+ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
 ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
 ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")

-assert sum(ref_x_remove_padding -
-           x_remove_padding) == 0, 'Check x_remove_padding failed.'
-assert sum(ref_cum_offsets_out -
-           cum_offsets_out) == 0, 'Check cum_offsets_out failed.'
-assert sum(ref_padding_offset -
-           padding_offset) == 0, 'Check padding_offset failed.'
-assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, 'Check cu_seqlens_q failed.'
-assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, 'Check cu_seqlens_k failed.'
+assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
+assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
+assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
+assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
+assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
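Note: the reference arrays in this hunk follow directly from per-request sequence lengths of [4, 3, 6] and a padded width of 10; both values are inferred here from the expected outputs, since the setup lines fall outside the hunk. A small NumPy sketch that reproduces them:

    import numpy as np

    seq_lens = np.array([4, 3, 6])  # valid tokens per request (assumed)
    pad_len = 10                    # padded width of input_ids (assumed)

    # Exclusive prefix sum of the padding removed before each request.
    cum_offsets_out = np.concatenate([[0], np.cumsum(pad_len - seq_lens)[:-1]])
    # Each surviving token is shifted left by its request's cumulative padding.
    padding_offset = np.repeat(cum_offsets_out, seq_lens)
    # Cumulative sequence lengths, shared by Q and K in this test.
    cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)])

    print(cum_offsets_out)  # [ 0  6 13]
    print(padding_offset)   # [ 0  0  0  0  6  6  6 13 13 13 13 13 13]
    print(cu_seqlens)       # [ 0  4  7 13]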
@@ -21,10 +21,15 @@ paddle.seed(2023)

 pre_ids = paddle.to_tensor(
    [[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
-    "int64")
-logits = paddle.to_tensor([[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
-                           [0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1]],
-                          "float32")
+    "int64",
+)
+logits = paddle.to_tensor(
+    [
+        [0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
+        [0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1],
+    ],
+    "float32",
+)
 penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
 frequency_scores = paddle.to_tensor([0.1, 0.1], "float32")
 presence_scores = paddle.to_tensor([0.0, 0.0], "float32")
@@ -88,78 +93,536 @@ ref_logits = np.array(
 )
 diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
 print("diff_logits\n", diff_logits)
-assert diff_logits < 1e-6, 'Check failed.'
+assert diff_logits < 1e-6, "Check failed."

 pre_ids = paddle.to_tensor(
-    [[
-        2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8,
-        7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4,
-        9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6,
-        1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8,
-        7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4,
-        7, 5, 5, 7, 6, 9, 3, 9
-    ],
-    [
-        7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1,
-        9, 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1,
-        5, 1, 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8,
-        8, 4, 8, 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6,
-        3, 5, 1, 4, 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6,
-        4, 7, 7, 3, 8, 5, 1, 9, 1, 2, 8, 6, 8
-    ]])
+    [
+        [2, 3, 3, ...],  # unchanged values, reformatted to one element per line
+        [7, 8, 1, ...],  # unchanged values, reformatted to one element per line
+    ]
+)
 logits = paddle.to_tensor(
-    [[
-        0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748,
-        0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807,
-        0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218,
-        0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679,
-        0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888,
-        0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911,
-        0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448,
-        0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415,
-        0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104,
-        0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064,
-        0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672,
-        0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840,
-        0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613,
-        0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254,
-        0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189,
-        0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407,
-        0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923,
-        0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550,
-        0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222,
-        0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845,
-        0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070,
-        0.98481447, 0.77367836
-    ],
-    [
-        0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417,
-        0.01848883, 0.78326035, 0.99228370, 0.81447607, 0.02627683,
-        0.51033205, 0.98703283, 0.15247856, 0.77640921, 0.60799915,
-        0.87518770, 0.76818430, 0.86542630, 0.31795895, 0.04829503,
-        0.85567141, 0.30271924, 0.67515039, 0.59728831, 0.78710967,
-        0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547,
-        0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396,
-        0.95318258, 0.93987304, 0.61142951, 0.26737869, 0.52285451,
-        0.03479086, 0.61631846, 0.66777998, 0.15736090, 0.00447258,
-        0.37035006, 0.15281211, 0.95372260, 0.25963321, 0.61036694,
-        0.15020694, 0.19171195, 0.55252832, 0.00391038, 0.31052542,
-        0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293,
-        0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045,
-        0.13202040, 0.07133843, 0.75313550, 0.17111187, 0.80716974,
-        0.00172165, 0.83906764, 0.73240769, 0.85843354, 0.11042888,
-        0.07912333, 0.33689004, 0.22334915, 0.59059596, 0.52789515,
-        0.29831955, 0.39515004, 0.55602801, 0.83818001, 0.05865780,
-        0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544,
-        0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529,
-        0.29571742, 0.76361233, 0.72476572, 0.18568406, 0.85430276,
-        0.02057583, 0.76195669, 0.65507215, 0.69129735, 0.25084621,
-        0.75223947, 0.06064088, 0.20287007, 0.35887691, 0.75043523,
-        0.47575447, 0.40021798, 0.44464844, 0.67975360, 0.40443239,
-        0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721,
-        0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736,
-        0.87317258, 0.88229448, 0.79217333
-    ]])
+    [
+        [0.16274983, 0.61470598, ...],  # unchanged values, reformatted to one element per line
+        [0.12336024, 0.74152875, ...],  # unchanged values, reformatted to one element per line
+    ]
+)
 # pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
 # logits = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
 penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
@@ -195,60 +658,270 @@ print("min_len\n", min_len)
 print("eos_token_id\n", eos_token_id)

 ref_logits = np.array(
-    [[
-        -10000000000., -10000000000., 1.88733959, 1.64010835, 1.01505280,
-        0.76633495, 1.85296881, 0.48100317, 0.10923190, 0.84437162, 0.72540450,
-        0.30929613, 0.27229437, 1.35019088, 0.80630332, 0.21343444, 0.49664113,
-        1.52182436, 0.23197991, 0.21925054, 0.09377026, 1.63073432, 1.44519603,
-        1.20953357, 0.33403599, 1.68321562, 1.59299767, 1.56043208, 1.50659060,
-        1.97175777, 0.26843736, 0.32055250, 0.30538794, 0.12457460, 1.47712541,
-        0.69443822, 1.47366011, 1.56357217, 0.64136654, 1.59812617, 0.88428545,
-        1.26660895, 0.16033916, 1.26734281, 0.39577886, 1.10693574, 0.22285062,
-        1.81036830, 0.42473382, 1.63174939, 1.67505860, 1.41958964, 0.71368366,
-        0.57430208, 1.74325645, 0.35358793, 1.97451699, 1.52259982, 0.08180470,
-        0.74362129, 1.26634097, 0.49379003, 0.42253003, 1.15235341, 1.48693395,
-        0.81227344, 1.13814020, 1.37113857, 0.58065367, 0.35732555, 0.70330191,
-        1.94031680, 1.41571164, 1.08519757, 0.29424474, 1.80966353, 0.04188210,
-        0.72823226, 0.04990132, 1.77748108, 1.77790904, 1.72432923, 1.16124380,
-        1.91166508, 0.41106221, 0.59740692, 1.39305866, 0.73722488, 1.70632446,
-        1.00480378, 0.35132489, 1.22160280, 1.76406348, 1.97350430, 0.48689091,
-        0.34426814, 1.56321704, 0.50330973, 0.96377015, 1.65624845, 0.20399629,
-        1.80951846, 1.33814967, 1.43821251, 0.81321514, 1.18920588, 1.40425825,
-        1.81683099, 0.00658068, 0.22580932, 1.79309309, 1.38229883, 0.58947235,
-        1.24054444, 0.74667758, 1.97822285, 0.93020374, 1.31829166, 1.46045291,
-        0.25581691, 0.25634488, 0.86030912, 1.50022912, 0.87124407, 0.96172053,
-        1.51174140, 1.96962893, 1.54735672
-    ],
-    [
-        -10000000000., -10000000000., -40000., 3.97204876, 1.79057670,
-        0.07395532, 3.13304138, 3.96913481, 3.25790429, -40000., 2.04132819,
-        3.94813132, 0.60991424, 3.10563684, 2.43199658, 3.50075078,
-        3.07273722, 3.46170521, 1.27183580, 0.19318011, 3.42268562,
-        1.21087694, 2.70060158, 2.38915324, 3.14843869, 3.00446773,
-        2.27349496, 1.96343100, 3.66040802, 2.38182187, 3.97928929,
-        2.36147618, 2.33071637, 1.15083730, 2.12949586, 3.81273031,
-        3.75949216, 2.44571805, 1.06951475, 2.09141803, 0.13916343,
-        2.46527386, 2.67111993, 0.62944359, 0.01789032, 1.48140025,
-        0.61124843, 3.81489038, 1.03853285, 2.44146776, 0.60082775,
-        0.76684779, 2.21011329, 0.01564152, 1.24210167, 3.85980701,
-        1.70344496, 0.22521044, 3.98914671, 0.07425172, 3.32806015,
-        0.42807373, 2.25736713, 1.52038097, 2.04380178, 0.52808160,
-        0.28535372, 3.01254201, 0.68444747, 3.22867894, 0.00688660,
-        3.35627055, 2.92963076, 3.43373418, 0.44171551, 0.31649333,
-        1.34756017, 0.89339662, 2.36238384, 2.11158061, 1.19327819,
-        1.58060014, 2.22411203, 3.35272002, 0.23463120, 1.02618670,
-        3.06496596, 1.40762556, 0.16633384, 2.36630177, 1.23119164,
-        3.78436017, 0.43038681, 2.62300563, 1.51314116, 1.18286967,
-        3.05444932, 2.89906287, 0.74273622, 3.41721106, 0.08230332,
-        3.04782677, 2.62028861, 2.76518941, 1.00338483, 3.00895786,
-        0.24256352, 0.81148028, 1.43550766, 3.00174093, 1.90301788,
-        1.60087192, 1.77859378, 2.71901441, 1.61772954, 2.84211969,
-        0.87128991, 2.02273703, 3.56150365, 0.26646885, 1.15152383,
-        2.83093548, 1.69713056, 3.21678710, 1.70842946, 3.49269032,
-        3.52917790, 3.16869330
-    ]],
+    [
+        [-10000000000.0, -10000000000.0, 1.88733959, ...],  # unchanged values, one element per line
+        [-10000000000.0, -10000000000.0, -40000.0, 3.97204876, ...],  # unchanged values, one element per line
+    ],
    "float32",
 )
 diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
 print("diff_logits\n", diff_logits)
-assert diff_logits < 1e-6, 'Check failed.'
+assert diff_logits < 1e-6, "Check failed."
@@ -21,19 +21,30 @@ paddle.seed(2023)

 pre_ids_all = paddle.to_tensor(
    [[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
-    "int64")
-input_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
-                              [1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1]],
-                             "int64")
+    "int64",
+)
+input_ids = paddle.to_tensor(
+    [
+        [1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
+        [1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1],
+    ],
+    "int64",
+)
 seq_lens_this_time = paddle.to_tensor([1, 1], "int32")
 seq_lens_encoder = paddle.to_tensor([1, 1], "int32")
 seq_lens_decoder = paddle.to_tensor([1, 1], "int32")
 step_idx = paddle.to_tensor([1, 1], "int64")
 stop_flags = paddle.to_tensor([0, 1], "bool")
 print("pre_ids_all\n", pre_ids_all)
-set_value_by_flags_and_idx(pre_ids_all, input_ids, seq_lens_this_time,
-                           seq_lens_encoder, seq_lens_decoder, step_idx,
-                           stop_flags)
+set_value_by_flags_and_idx(
+    pre_ids_all,
+    input_ids,
+    seq_lens_this_time,
+    seq_lens_encoder,
+    seq_lens_decoder,
+    step_idx,
+    stop_flags,
+)
 print("pre_ids_all\n", pre_ids_all)
 print("input_ids\n", input_ids)
 print("seq_lens_this_time\n", seq_lens_this_time)
@@ -73,4 +84,4 @@ ref_pre_ids_all = np.array(
 )
 diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy()))
 print("diff_pre_ids_all\n", diff_pre_ids_all)
-assert diff_pre_ids_all == 0, 'Check failed.'
+assert diff_pre_ids_all == 0, "Check failed."
@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
 max_block_num = block_bs * max_seq_len // block_size
 free_list_len = int(max_block_num * (1 - block_ratio))
 free_list_len = np.full([1], free_list_len, "int32")
-free_list = np.arange(max_block_num - 1,
-                      max_block_num - free_list_len - 1,
-                      -1,
-                      dtype="int32")
+free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")

 encoder_block_lens = np.zeros([max_bs], "int32")
 used_list_len = np.zeros([max_bs], "int32")
@@ -53,19 +50,15 @@ encoder_block_id = 0
 for i in range(bs):
    enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
    encoder_block_lens[i] = enc_block_num
-    dec_block_num = (seq_lens_decoder[i] + block_size -
-                     1) // block_size - enc_block_num
+    dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
    used_list_len[i] = dec_block_num
-    block_tables[i, :enc_block_num] = np.arange(
-        encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
+    block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
    encoder_block_id += enc_block_num
    if dec_block_num > 0:
-        block_tables[
-            i, enc_block_num:enc_block_num +
-            dec_block_num] = free_list[free_list_len[0] - 1 -
-                                       dec_block_num:free_list_len[0] - 1]
-        free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
-                  1] = -1
+        block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
+            free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
+        ]
+        free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
        free_list_len[0] -= dec_block_num
        assert free_list_len[0] >= 0
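Note: the slice arithmetic in the hunk above implements a small block allocator: encoder blocks are handed out sequentially, while decoder blocks are carved off the tail of free_list. A self-contained sketch of just the decoder-side allocation, with made-up sizes:

    import numpy as np

    free_list_len = np.array([16], "int32")
    # Free block ids in descending order: 39, 38, ..., 24.
    free_list = np.arange(39, 39 - free_list_len[0], -1, dtype="int32")
    block_tables = np.full([1, 8], -1, "int32")

    enc_block_num, dec_block_num = 2, 3
    # Take dec_block_num ids from just below the top of the free list ...
    block_tables[0, enc_block_num : enc_block_num + dec_block_num] = free_list[
        free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
    ]
    # ... then mark them as used and shrink the list, exactly as in the test.
    free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
    free_list_len[0] -= dec_block_num
    assert free_list_len[0] >= 0

    print(block_tables)  # [[-1 -1 27 26 25 -1 -1 -1]]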
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
 # print("step_idx: ", step_idx)
 # print("next_tokens: ", next_tokens)

-step_paddle(stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
-            seq_lens_encoder, seq_lens_decoder, block_tables,
-            encoder_block_lens, is_block_step, step_block_list, step_lens,
-            recover_block_list, recover_lens, need_block_list, need_block_len,
-            used_list_len, free_list, free_list_len, input_ids, pre_ids,
-            step_idx, next_tokens, first_token_ids, block_size,
-            encoder_decoder_block_num)
+step_paddle(
+    stop_flags,
+    seq_lens_this_time,
+    ori_seq_lens_encoder,
+    seq_lens_encoder,
+    seq_lens_decoder,
+    block_tables,
+    encoder_block_lens,
+    is_block_step,
+    step_block_list,
+    step_lens,
+    recover_block_list,
+    recover_lens,
+    need_block_list,
+    need_block_len,
+    used_list_len,
+    free_list,
+    free_list_len,
+    input_ids,
+    pre_ids,
+    step_idx,
+    next_tokens,
+    first_token_ids,
+    block_size,
+    encoder_decoder_block_num,
+)

 print("-" * 50 + "after step op" + "-" * 50)
 print("stop_flags: ", stop_flags)
@@ -30,8 +30,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
 print("topk_ids\n", topk_ids)
 print("next_tokens\n", next_tokens)
 print("stop_flags\n", stop_flags)
-set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
-                          False)
+set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, False)
 print("topk_ids\n", topk_ids)
 print("next_tokens\n", next_tokens)
 print("stop_flags\n", stop_flags)
@@ -40,44 +39,220 @@ print("end_ids\n", end_ids)

 ref_topk_ids = np.array(
    [
-        0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, 18, 19, 20,
-        0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, -1, 37, 38, 39,
-        -1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, 0, -1, 0, 57, -1, 59,
-        60, 0, 0, 63
+        0, 0, 2, 3, -1, ...  # unchanged values, reformatted to one element per line
    ],
    "int64",
 )
 ref_next_tokens = np.array(
    [
-        0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, 18, 19, 20,
-        0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, 0, 37, 38, 39, 0,
-        0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, 0, 0, 0, 57, 0, 59, 60, 0,
-        0, 63
+        0, 0, 2, 3, 0, ...  # unchanged values, reformatted to one element per line
    ],
    "int64",
 )
 ref_stop_flags = np.array(
    [
-        True, True, True, True, True, True, True, True, True, False, False,
-        True, False, True, True, False, False, True, False, False, False, True,
-        False, False, True, False, False, False, True, False, False, False,
-        True, True, True, True, True, False, False, False, True, True, True,
-        True, True, True, False, True, True, False, False, True, False, False,
-        True, True, True, False, True, False, False, True, True, False
+        True, True, True, ...  # unchanged values, reformatted to one element per line
    ],
    "bool",
 )
 diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
 print("diff_topk_ids\n", diff_topk_ids)
-assert diff_topk_ids == 0, 'Check failed.'
+assert diff_topk_ids == 0, "Check failed."
 diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
 print("diff_next_tokens\n", diff_next_tokens)
-assert diff_next_tokens == 0, 'Check failed.'
-diff_stop_flags = np.sum(
-    np.abs(
-        ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
+assert diff_next_tokens == 0, "Check failed."
+diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
 print("diff_stop_flags\n", diff_stop_flags)
-assert diff_stop_flags == 0, 'Check failed.'
+assert diff_stop_flags == 0, "Check failed."

 # test beam_search=True
 topk_ids = paddle.arange(0, bs, dtype="int64")
@@ -88,8 +263,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
 print("topk_ids\n", topk_ids)
 print("next_tokens\n", next_tokens)
 print("stop_flags\n", stop_flags)
-set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
-                          True)
+set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, True)
 print("topk_ids\n", topk_ids)
 print("next_tokens\n", next_tokens)
 print("stop_flags\n", stop_flags)
@@ -98,42 +272,217 @@ print("end_ids\n", end_ids)

 ref_topk_ids = np.array(
    [
-        0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, 18, 19,
-        20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, 36, 37, 0,
-        -1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58,
-        -1, 60, 61, -1, 63
+        0, 1, 2, 3, 4, ...  # unchanged values, reformatted to one element per line
    ],
    "int64",
 )
 ref_next_tokens = np.array(
    [
-        0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, 18, 19, 20,
-        0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, 36, 37, 0, 0, 0,
-        41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 0, 60, 61,
-        0, 63
+        0, 1, 2, 3, 4, ...  # unchanged values, reformatted to one element per line
    ],
    "int64",
 )
 ref_stop_flags = np.array(
    [
-        False, False, False, False, False, True, False, False, True, False,
-        False, True, True, False, False, False, True, False, False, False,
-        False, True, False, False, False, False, True, True, False, False,
-        True, True, True, False, False, False, False, False, True, True, True,
-        False, True, True, False, False, False, True, True, False, True, True,
-        True, False, True, True, True, True, False, True, False, False, True,
-        False
+        False, False, False, ...  # unchanged values, reformatted to one element per line
    ],
    "bool",
 )
 diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
 print("diff_topk_ids\n", diff_topk_ids)
-assert diff_topk_ids == 0, 'Check failed.'
+assert diff_topk_ids == 0, "Check failed."
 diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
 print("diff_next_tokens\n", diff_next_tokens)
-assert diff_next_tokens == 0, 'Check failed.'
-diff_stop_flags = np.sum(
-    np.abs(
-        ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
+assert diff_next_tokens == 0, "Check failed."
+diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
 print("diff_stop_flags\n", diff_stop_flags)
-assert diff_stop_flags == 0, 'Check failed.'
+assert diff_stop_flags == 0, "Check failed."
@@ -60,9 +60,17 @@ print("stop_nums:\n", stop_nums)
 print("next_tokens:\n", next_tokens)
 print("is_block_step:\n", is_block_step)

-update_inputs(stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder,
-              seq_lens_decoder, input_ids, stop_nums, next_tokens,
-              is_block_step)
+update_inputs(
+    stop_flags,
+    not_need_stop,
+    seq_lens_this_time,
+    seq_lens_encoder,
+    seq_lens_decoder,
+    input_ids,
+    stop_nums,
+    next_tokens,
+    is_block_step,
+)

 print("-" * 50)
 print("stop_flags:\n", stop_flags)
@@ -75,32 +83,269 @@ print("stop_nums:\n", stop_nums)
 print("next_tokens:\n", next_tokens)

 ref_not_need_stop_out = np.array([True])
-ref_seq_lens_this_time_out = np.array([
-    0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1,
-    0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1
-], "int32")
-ref_seq_lens_encoder_out = np.array([
-    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-], "int32")
-ref_seq_lens_decoder_out = np.array([
-    0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0,
-    24, 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0,
-    46, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-], "int32")
-input_ids_np[:, 0] = np.array([
-    6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, 5,
-    9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, 6, 7,
-    4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8
-], "int64")
+ref_seq_lens_this_time_out = np.array(
+    [0, 0, 1, 0, ...],  # unchanged values, reformatted to one element per line
+    "int32",
+)
+ref_seq_lens_encoder_out = np.array(
+    [0, 0, 0, 0, ...],  # 64 zeros, reformatted to one element per line
+    "int32",
+)
+ref_seq_lens_decoder_out = np.array(
+    [0, 0, 2, 0, ...],  # unchanged values, reformatted to one element per line
+    "int32",
+)
+input_ids_np[:, 0] = np.array(
+    [6, 5, 9, 8, ...],  # unchanged values, reformatted to one element per line
+    "int64",
+)

-assert not_need_stop.numpy(
-) == ref_not_need_stop_out, 'Check not_need_stop failed.'
-assert np.all(seq_lens_this_time.numpy() ==
-              ref_seq_lens_this_time_out), 'Check seq_lens_this_time failed.'
-assert np.all(seq_lens_encoder.numpy() ==
-              ref_seq_lens_encoder_out), 'Check seq_lens_encoder failed.'
-assert np.all(seq_lens_decoder.numpy() ==
-              ref_seq_lens_decoder_out), 'Check seq_lens_decoder failed.'
-assert np.all(input_ids.numpy() == input_ids_np), 'Check input_ids failed.'
+assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
+assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
+assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
+assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
+assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."
@@ -29,16 +29,15 @@ def np_quant_weight_int4(weight_np):
    weight = np.transpose(weight_np, [1, 0])  # n,k
    max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1)  # n,k => n,1
    quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7)  # n,k
    quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF)  # pack int4, [n,k//2]
    weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
    return quanted_weight, weight_scales.astype(np.float32)


def np_quant_weight(weight_np, algo="weight_only_int8"):
    assert weight_np.dtype == np.float32

    if algo == "weight_only_int4":
        return np_quant_weight_int4(weight_np)

    weight = np.transpose(weight_np, [1, 0])
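# A minimal sketch (editor's illustration, not part of the diff) of the
# nibble packing used in np_quant_weight_int4 above: two signed int4 values
# share one byte, the odd-index value in the high nibble and the even-index
# value in the low nibble. All names below are local to this sketch.
import numpy as np

vals = np.array([[3, -2, 7, -7]], dtype=np.int32)  # one row of int4-range values
packed = (((vals[:, 1::2] & 0xF) << 4) | (vals[:, ::2] & 0xF)).astype(np.uint8)  # [1, k//2]

# Unpack: mask out each nibble, then sign-extend from 4 bits.
low = (packed & 0xF).astype(np.int32)
high = ((packed >> 4) & 0xF).astype(np.int32)
low = np.where(low > 7, low - 16, low)
high = np.where(high > 7, high - 16, high)
restored = np.stack([low, high], axis=-1).reshape(vals.shape)
assert np.array_equal(restored, vals)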
@@ -56,7 +55,7 @@ def int8_to_bin_np(value):
def int8_to_bin(value):
    if not -128 <= value <= 127:
        raise ValueError("int8 value must be between -128 and 127")
    return format(value & 0xFF, "08b")  # '08b' formats as 8-bit binary, zero-padded
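
# Example calls for int8_to_bin (editor's sketch): `value & 0xFF` maps
# negatives to their two's-complement byte before formatting.
print(int8_to_bin(5))     # 00000101
print(int8_to_bin(-1))    # 11111111
print(int8_to_bin(-128))  # 10000000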

# 1) preparation
@@ -70,7 +69,7 @@ w_np = (np.random.random((k, n)).astype(np.float32) - 0.5) * 10
qw_np, wscale_np = np_quant_weight(w_np, algo)

# 3) xpu calculation
dtype = "float32"
x_pd = paddle.to_tensor(w_np, dtype=dtype)
qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1)
qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
@@ -83,12 +82,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
# comparison
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
print(f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}")
sum_diff = np.sum(np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"sum_diff: {sum_diff}")

@@ -37,4 +37,4 @@ python benchmark_serving.py \
    --num-prompts 1 \
    --max-concurrency 1 \
    --save-result
```

@@ -15,7 +15,7 @@ We provide two transmission methods for KV Cache, targeting intra-machine and in
Uses cudaMemcpyPeer for KV Cache transmission between two GPUs within a single machine, offering low latency and high throughput.

### Inter-machine Transmission
For transmission between multiple machines, the KV Cache is transferred over a high-speed RDMA network. We provide the `rdma_comm` high-speed transmission library for cross-machine KV Cache transfer.

## PD Disaggregated Scheduling

@@ -60,7 +60,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
    --cache-queue-port 8187 \
    --tensor-parallel-size 4 \
    --quantization wint4 \
    --innode-prefill-ports 8182 \
    --splitwise-role "decode"
```

@@ -72,7 +72,8 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
### Multi-machine Disaggregated Deployment

#### Prerequisite: Redis
* Installation via `conda`

```bash
# Install
conda install redis
@@ -80,7 +81,8 @@ conda install redis
# Launch
nohup redis-server > redis.log 2>&1 &
```

* Installation via `apt`

```bash
# Install
sudo apt install redis-server -y
@@ -88,7 +90,8 @@ sudo apt install redis-server -y
# Launch
sudo systemctl start redis-server
```

* Installation via `yum`

```bash
# Install
sudo yum install redis -y
```

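Whichever install route you take, it can help to confirm Redis is actually reachable before launching the disaggregated services. A minimal check, assuming the `redis` Python package (`pip install redis`) and the default host/port:

```python
import redis  # assumes `pip install redis`

# Ping the Redis instance the scheduler will use (default host/port assumed).
client = redis.Redis(host="127.0.0.1", port=6379)
assert client.ping(), "Redis is not reachable"
print("Redis is up")
```
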
@@ -38,6 +38,7 @@ conda install redis
# Launch
nohup redis-server > redis.log 2>&1 &
```

### apt installation (Debian/Ubuntu)

```bash
@@ -57,6 +58,7 @@ sudo systemctl start redis
```

## Launching FastDeploy

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --port 8801 \
@@ -72,6 +74,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
    --scheduler-min-load_score 3 \
    --scheduler-load-shards-num 1
```

[Scheduler Launching Parameter](../online_serving/scheduler.md)

### Deployment notes:

@@ -36,4 +36,4 @@ python -m fastdeploy.entrypoints.openai.api_server \

Set `enable_prefix_caching=True` when launching FastDeploy. Enable CPU caching via `swap_space` based on available machine memory.

A test example is provided: `demo/offline_prefix_caching_demo.py`

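As a client-side illustration (a sketch; the port and model name below are assumptions, not taken from the demo), two chat requests that share a long system prompt let the second one reuse the prefix KV blocks cached by the first:

```python
from openai import OpenAI

# Endpoint and model name are assumptions; adjust to your deployment.
client = OpenAI(base_url="http://0.0.0.0:8801/v1", api_key="EMPTY")
shared_system_prompt = "You are a meticulous assistant. " * 50  # long shared prefix

for question in ["What is a KV cache?", "Why does prefix caching reduce latency?"]:
    resp = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": shared_system_prompt},
            {"role": "user", "content": question},
        ],
    )
    print(resp.choices[0].message.content)
```
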
@@ -18,8 +18,9 @@ Interfaces that support toggling the reasoning mode:
For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request.

### Quick Start
When launching the model service, specify the parser name via the `--reasoning-parser` argument.
This parser processes the model's output and extracts the `reasoning_content` field.

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model /path/to/your/model \
@@ -29,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
    --quantization wint4 \
    --reasoning-parser ernie-45-vl
```

Next, make a request to the model, which should return the reasoning content in the response.

```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
  -H "Content-Type: application/json" \
@@ -43,10 +46,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
  "metadata": {"enable_thinking": true}
}'
```

The `reasoning_content` field contains the reasoning steps that lead to the final conclusion, while the `content` field holds the conclusion itself.

### Streaming chat completions
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field of the chat completion response chunks.

```python
from openai import OpenAI
# Set the OpenAI API key and API base to point at the FastDeploy API server.
@@ -69,4 +74,4 @@ for chunk in chat_response:
    if chunk.choices[0].delta is not None:
        print(chunk.choices[0].delta, end='')
print("\n")
```

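Building on the loop above, the following sketch collects the two streams separately; it reuses the `chat_response` iterator from the example, and `reasoning_content` is the attribute the parser adds to each delta:

```python
# In place of the print loop above: split the streamed deltas into
# reasoning and answer text.
reasoning_parts, answer_parts = [], []
for chunk in chat_response:
    delta = chunk.choices[0].delta
    if getattr(delta, "reasoning_content", None):
        reasoning_parts.append(delta.reasoning_content)
    if getattr(delta, "content", None):
        answer_parts.append(delta.content)

print("reasoning:", "".join(reasoning_parts))
print("answer:", "".join(answer_parts))
```
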
@@ -10,22 +10,22 @@ This project implements an efficient **Speculative Decoding** inference framewor

- **Ngram**

- **MTP (Multi-Token Prediction)**
  - ✅ Supported: TP Sharding
  - ✅ Supported: Shared Prefix
  - ✅ Supported: TP Sharding + PD Separation
  - ⏳ Coming Soon: EP + DP + PD Separation
  - ⏳ Coming Soon: Chunk-prefill support
  - ⏳ Coming Soon: Multi-layer MTP

---

### Coming Soon

- Draft Model
- Eagle
- Hydra
- Medusa
- ...

---
@@ -54,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor

## 🚀 Using Multi-Token Prediction (MTP)

For detailed theory, refer to:
📄 [DeepSeek-V3 Paper](https://arxiv.org/pdf/2412.19437)

### TP Sharding Mode
@@ -147,4 +147,4 @@ python -m fastdeploy.entrypoints.openai.api_server \
    --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
    --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
```

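When scripting launches, the `--speculative-config` value can be composed programmatically. A small sketch that reproduces the JSON above (the `${mtp_model_path}` placeholder is kept verbatim):

```python
import json

# Compose the value passed to --speculative-config.
spec_cfg = {
    "method": "mtp",
    "num_speculative_tokens": 1,
    "model": "${mtp_model_path}",  # placeholder, as in the command above
}
print(json.dumps(spec_cfg))
```
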
@@ -132,4 +132,3 @@ Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```

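A quick way to tabulate such results (a sketch assuming `result.jsonl` sits in the working directory, one JSON object per line as above):

```python
import json

# Summarize accuracy results, one JSON object per line.
with open("result.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        print(rec["task"], "accuracy:", rec["accuracy"], "latency:", rec["latency"])
```
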
@@ -6,4 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
- [Kunlun XPU Installation](kunlunxin_xpu.md)
- [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md)
- [Hygon DCU Installation](hygon_dcu.md)

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
```

## 2. Start service

```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
    --gpu-memory-utilization=0.8
```

### Send requests

Send requests using either curl or Python.

@@ -78,4 +79,4 @@ response = client.chat.completions.create(
    stream=False,
)
print(response)
```

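For reference, a complete version of that Python client might look like the sketch below; the base URL and model name are assumptions and should match the flags used to start the service:

```python
from openai import OpenAI

# Base URL is an assumption; use the --port the service was started with.
client = OpenAI(base_url="http://127.0.0.1:8188/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",  # model name is an assumption
    messages=[{"role": "user", "content": "Introduce the Hygon DCU in one sentence."}],
    stream=False,
)
print(response)
```
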
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user