polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

7
.flake8 Normal file
View File

@@ -0,0 +1,7 @@
[flake8]
ignore = E203, E402, E501, E731, E741, W503, W605, E722
max-line-length = 119
# E402: module level import not at top of file
per-file-ignores =
__init__.py:F401,F403,E402

View File

@@ -2,7 +2,7 @@ name: CI
on:
pull_request:
branches:
branches:
- develop
- 'release/*'
workflow_dispatch:
@@ -86,4 +86,4 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci.sh
"
"

View File

@@ -2,7 +2,7 @@ name: CI_XPU
on:
pull_request:
branches:
branches:
- develop
- 'release/*'
workflow_dispatch:
@@ -63,7 +63,7 @@ jobs:
if [[ "$last_char" =~ [0-3] ]]; then
gpu_id="$last_char"
else
gpu_id="0"
gpu_id="0"
fi
FD_API_PORT=$((9180 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
@@ -84,4 +84,4 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci_xpu.sh
"
"

View File

@@ -5,12 +5,27 @@ default_stages:
- pre-commit # Run locally
# - manual # Run in CI
repos:
- repo: https://github.com/psf/black.git
rev: 22.8.0
hooks:
- id: black
files: \.(py|pyi)$
additional_dependencies: [toml]
# 自动排序
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
# 代码检查
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
hooks:
- id: ruff
args: [--output-format, github, --fix, --line-length=120]
args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
# # 拼写检查
# - repo: https://github.com/codespell-project/codespell
# rev: v2.4.1
@@ -18,17 +33,13 @@ repos:
# - id: codespell
# additional_dependencies: ['tomli']
# args: ['--toml', 'pyproject.toml']
# 自动排序
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:
- id: isort
# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29
hooks:
- id: pymarkdown
args: [fix]
args: ["-d", "MD029,MD031", fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:

View File

@@ -8,7 +8,7 @@
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>
<p align="center">
@@ -17,8 +17,8 @@
|
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
</p>
--------------------------------------------------------------------------------

View File

@@ -131,4 +131,4 @@ python benchmarks/benchmark_mtp.py \
--s_itl-base-model主模型的解码延迟可由上述的性能压测工具获得与batch-size一一对应
--dataset-name指定数据集类指定为"EBChat"可读取转存的FD格式数据集
--dataset-path测试数据集路径
```
```

View File

@@ -29,13 +29,13 @@ from typing import Optional
import aiohttp
from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
"""Input for requesting LLMs via API"""
no: int
prompt: str
history_QA: Optional[dict]
@@ -55,6 +55,7 @@ class RequestFuncInput:
@dataclass
class RequestFuncOutput:
"""Output for requesting LLMs via API"""
no: int = 0
generated_text: str = ""
reasoning_content: str = ""
@@ -66,7 +67,7 @@ class RequestFuncOutput:
itl: list = field(default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
prompt_tokens: int = 0 # 推理侧返回输入token数
prompt_tokens: int = 0 # 推理侧返回输入token数
error: str = ""
@@ -76,12 +77,9 @@ async def async_request_eb_openai_chat_completions(
) -> RequestFuncOutput:
"""Request an LLM using EB OpenAI"""
api_url = request_func_input.api_url
assert api_url.endswith(
("completions", "profile")
), "OpenAI Chat Completions API URL must end with 'completions'."
assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
@@ -91,7 +89,7 @@ async def async_request_eb_openai_chat_completions(
"stream": True,
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True
"continuous_usage_stats": True,
},
}
# 超参由yaml传入
@@ -99,8 +97,8 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
headers = {
"Content-Type": "application/json",
@@ -115,16 +113,14 @@ async def async_request_eb_openai_chat_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, type(chunk))
timestamp = time.perf_counter()
@@ -138,22 +134,20 @@ async def async_request_eb_openai_chat_completions(
ttft = timestamp - st
output.ttft = ttft
# cached_tokens
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
output.prompt_len = (
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
)
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
output.generated_text += content or ""
output.reasoning_content += reason_content or ""
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage", {}):
output.output_tokens = usage.get(
"completion_tokens", 0)
output.prompt_tokens = usage.get(
"prompt_tokens", 0)
output.output_tokens = usage.get("completion_tokens", 0)
output.prompt_tokens = usage.get("prompt_tokens", 0)
most_recent_timestamp = timestamp
@@ -166,7 +160,12 @@ async def async_request_eb_openai_chat_completions(
output.latency = most_recent_timestamp - st
else:
error_text = await response.text()
print("####error response:", error_text, "####payload:", payload)
print(
"####error response:",
error_text,
"####payload:",
payload,
)
output.error = error_text or ""
output.success = False
except Exception:
@@ -194,15 +193,14 @@ async def async_request_eb_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
"stream": True,
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True
"continuous_usage_stats": True,
},
}
# 超参由yaml传入
@@ -210,12 +208,12 @@ async def async_request_eb_openai_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:", json.dumps(payload, ensure_ascii=False))
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
output = RequestFuncOutput()
@@ -227,8 +225,7 @@ async def async_request_eb_openai_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
@@ -236,8 +233,7 @@ async def async_request_eb_openai_completions(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, chunk.usage)
timestamp = time.perf_counter()
@@ -250,7 +246,7 @@ async def async_request_eb_openai_completions(
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
# First token
if not first_chunk_received:
first_chunk_received = True
@@ -259,26 +255,23 @@ async def async_request_eb_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
generated_text += text or ""
most_recent_timestamp = timestamp
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage"):
output.prompt_tokens = usage.get(
"prompt_tokens")
output.output_tokens = usage.get(
"completion_tokens")
output.prompt_tokens = usage.get("prompt_tokens")
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
"Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
@@ -294,8 +287,8 @@ async def async_request_eb_openai_completions(
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print("final_output:{}".format(output))
print(f"final_output:{output}")
if pbar:
pbar.update(1)
@@ -310,8 +303,7 @@ async def async_request_tgi(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
params = {
"max_new_tokens": request_func_input.output_len,
"do_sample": True,
@@ -358,8 +350,7 @@ async def async_request_tgi(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output.arrival_time.append(data["arrival_time"])
@@ -388,8 +379,7 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
@@ -414,8 +404,7 @@ async def async_request_trt_llm(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data:")
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
@@ -427,8 +416,7 @@ async def async_request_trt_llm(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
@@ -453,8 +441,7 @@ async def async_request_deepspeed_mii(
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
"""Request an LLM using Deepspeed MII"""
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"prompt": request_func_input.prompt,
@@ -472,19 +459,16 @@ async def async_request_deepspeed_mii(
st = time.perf_counter()
try:
async with session.post(url=request_func_input.api_url,
json=payload) as response:
async with session.post(url=request_func_input.api_url, json=payload) as response:
if response.status == 200:
parsed_resp = await response.json()
output.latency = time.perf_counter() - st
if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][
"text"]
output.generated_text = parsed_resp["choices"][0]["text"]
elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0]
else:
output.error = ("Unexpected response format: "
"neither 'choices' nor 'text' found")
output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
output.success = False
output.success = True
else:
@@ -510,26 +494,22 @@ async def async_request_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
"prompt": request_func_input.prompt,
# "temperature": 0.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
#"stream_options": {
# "stream_options": {
# "include_usage": True,
#},
# },
}
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -538,8 +518,7 @@ async def async_request_openai_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
@@ -547,8 +526,7 @@ async def async_request_openai_completions(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, type(chunk))
data = json.loads(chunk)
@@ -569,21 +547,19 @@ async def async_request_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
"Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
@@ -606,25 +582,24 @@ async def async_request_openai_audio(
"""Request an LLM using OpenAI"""
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(
("transcriptions", "translations"
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
("transcriptions", "translations")
), "OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
"stream_continuous_usage_stats": True
"stream_continuous_usage_stats": True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
@@ -639,9 +614,9 @@ async def async_request_openai_audio(
buffer.seek(0)
return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav')
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
@@ -653,24 +628,20 @@ async def async_request_openai_audio(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url,
data=form,
headers=headers) as response:
async with session.post(url=api_url, data=form, headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get(
"content")
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
@@ -678,13 +649,11 @@ async def async_request_openai_audio(
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
@@ -718,8 +687,11 @@ ASYNC_REQUEST_FUNCS = {
}
OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions,
async_request_eb_openai_chat_completions)
k
for k, v in ASYNC_REQUEST_FUNCS.items()
if v
in (
async_request_openai_completions,
async_request_eb_openai_chat_completions,
)
]

View File

@@ -26,9 +26,9 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Callable, Optional, Union
from PIL import Image
from typing import Any, Optional, Union
from PIL import Image
logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ class SampleRequest:
"""
Represents a single inference request for benchmarking.
"""
no: int
prompt: Union[str, Any]
history_QA: Union[str, Any]
@@ -48,6 +49,7 @@ class SampleRequest:
class BenchmarkDataset(ABC):
"""BenchmarkDataset"""
DEFAULT_SEED = 0
IS_MULTIMODAL = False
@@ -68,8 +70,7 @@ class BenchmarkDataset(ABC):
self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
self.random_seed = (random_seed
if random_seed is not None else self.DEFAULT_SEED)
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
self.data = None
self.hyperparameter_path = hyperparameter_path
self.hyperparameters = {}
@@ -85,8 +86,7 @@ class BenchmarkDataset(ABC):
NotImplementedError: If a subclass does not implement this method.
"""
# TODO (jenniferzhao): add support for downloading data
raise NotImplementedError(
"load_data must be implemented in subclasses.")
raise NotImplementedError("load_data must be implemented in subclasses.")
@abstractmethod
def sample(self, num_requests: int) -> list[SampleRequest]:
@@ -105,8 +105,7 @@ class BenchmarkDataset(ABC):
"""
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest],
num_requests: int) -> None:
def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
@@ -117,11 +116,9 @@ class BenchmarkDataset(ABC):
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
additional = random.choices(requests,
k=num_requests - len(requests))
additional = random.choices(requests, k=num_requests - len(requests))
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
logger.info("Oversampled requests to reach %d total samples.", num_requests)
def is_valid_sequence(
@@ -141,14 +138,12 @@ def is_valid_sequence(
"""
# Check for invalid conditions
prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len
< min_len)
output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long
or combined_too_long)
return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)
def process_image(image: Any) -> Mapping[str, Any]:
@@ -171,28 +166,25 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(image, dict) and 'bytes' in image:
image = Image.open(BytesIO(image['bytes']))
if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image):
image = image.convert("RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
return {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}
if isinstance(image, str):
image_url = (image if image.startswith(
("http://", "file://")) else f"file://{image}")
image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes.")
raise ValueError(
f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
)
class EBDataset(BenchmarkDataset):
@@ -243,8 +235,7 @@ class EBDataset(BenchmarkDataset):
new_output_len = int(entry["max_dec_len"])
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
prompt = self.apply_multimodal_chat_transformation(prompt, None)
samples.append(
SampleRequest(
no=cnt,
@@ -252,17 +243,20 @@ class EBDataset(BenchmarkDataset):
prompt_len=self.prompt_len,
history_QA=[],
expected_output_len=new_output_len,
))
)
)
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples
class EBChatDataset(BenchmarkDataset):
"""
Implements the ShareGPT dataset. Loads data from a JSON file and generates
sample requests based on conversation turns.
"""
prompt_len: int
def __init__(self, **kwargs) -> None:
@@ -296,8 +290,7 @@ class EBChatDataset(BenchmarkDataset):
new_output_len = int(entry.get("max_tokens", 12288))
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
prompt = self.apply_multimodal_chat_transformation(prompt, None)
samples.append(
SampleRequest(
no=cnt,
@@ -306,9 +299,9 @@ class EBChatDataset(BenchmarkDataset):
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
))
)
)
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples

View File

@@ -18,28 +18,16 @@ import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union
import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_dataset import EBChatDataset, EBDataset
from benchmark_serving import benchmark
def prepare_input_requests(
num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
dataset_mapping = {
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
}
try:
@@ -104,24 +92,27 @@ def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
def main(args):
base_url = f"http://{args.host}:{args.port}"
input_requests = prepare_input_requests(
args.num_prompts, args.dataset_name, args.dataset_path
)
input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)
if len(args.max_concurrency) != len(args.s_itl_base_model):
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
# Wramup
print("Starting warmup...")
with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f):
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
send_one_batch(
base_url,
max_concurrency,
input_requests[0:max_concurrency],
True,
)
# Benchmark
record = send_one_batch(base_url, max_concurrency, input_requests, False)
metric_header = f"Speed up"
metric_header = "Speed up"
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
for draft_token_step in args.draft_token_steps:
speedup = calculate_speedup(
@@ -130,11 +121,7 @@ def main(args):
s_itl,
record["mean_s_itl_ms"],
)
print(
"{:<40} {:<10.2f}".format(
f"Speed up on {draft_token_step} steps draft", speedup
)
)
print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
print("=" * 50)

File diff suppressed because it is too large Load Diff

View File

@@ -24,9 +24,11 @@ import os
from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
metrics: dict[str, list],
extra_info: dict[str, Any]) -> list:
def convert_to_pytorch_benchmark_format(
args: argparse.Namespace,
metrics: dict[str, list],
extra_info: dict[str, Any],
) -> list:
"""
Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record
@@ -54,12 +56,10 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
},
}
tp = record["benchmark"]["extra_info"]["args"].get(
"tensor_parallel_size")
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]
records.append(record)
@@ -68,6 +68,7 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder):
"""InfEncoder"""
def clear_inf(self, o: Any):
"""clear_inf"""
if isinstance(o, dict):
@@ -87,4 +88,3 @@ def write_to_json(filename: str, records: list) -> None:
"""write_to_json"""
with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder)

View File

@@ -25,32 +25,32 @@ import os
import random
import time
import warnings
import yaml
import requests
import copy
from argparse import ArgumentParser as FlexibleArgumentParser
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
RequestFuncOutput)
import requests
import yaml
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm.asyncio import tqdm
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@dataclass
class BenchmarkMetrics:
"""Class containing all metrics that are used in this script"""
completed: int
total_input: int
total_output: int
@@ -133,8 +133,7 @@ async def get_request(
input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
theta = 1.0 / (request_rate * burstiness)
for request in input_requests:
@@ -160,7 +159,7 @@ def calculate_metrics(
) -> tuple[BenchmarkMetrics, list[int]]:
"""Calculates various performance metrics based on the inputs and outputs."""
input_lens: list[int] = []
infer_input_lens: list[int] = [] # 推理侧输入token数
infer_input_lens: list[int] = [] # 推理侧输入token数
actual_output_lens: list[int] = []
total_input = 0
completed = 0
@@ -210,8 +209,9 @@ def calculate_metrics(
s_e2els.append(outputs[i].arrival_time[-1])
# 解码速度去掉首token
if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) /
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
s_decodes.append(
(outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
)
completed += 1
else:
actual_output_lens.append(0)
@@ -224,16 +224,13 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@@ -242,9 +239,9 @@ def calculate_metrics(
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
"All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -253,64 +250,50 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_s_decode=np.mean(s_decodes or 0) *
1, # ttfts is empty if streaming is not supported by backend
mean_s_decode=np.mean(s_decodes or 0) * 1, # ttfts is empty if streaming is not supported by backend
std_s_decode=np.std(s_decodes or 0) * 1,
median_s_decode=np.median(s_decodes or 0) * 1,
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1)
for p in selected_percentiles],
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
for p in selected_percentiles],
mean_s_ttft_ms=np.mean(s_ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
for p in selected_percentiles],
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
std_s_itl_ms=np.std(s_itls or 0) * 1000,
median_s_itl_ms=np.median(s_itls or 0) * 1000,
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
mean_input_len=np.mean(input_lens or 0) * 1,
std_input_len=np.std(input_lens or 0) * 1,
median_input_len=np.median(input_lens or 0) * 1,
percentiles_input_len=[(p, np.percentile(input_lens or 0, p))
for p in selected_percentiles],
percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
std_s_input_len=np.std(infer_input_lens or 0) * 1,
median_s_input_len=np.median(infer_input_lens or 0) * 1,
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p))
for p in selected_percentiles],
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
mean_output_len=np.mean(actual_output_lens or 0) * 1,
std_output_len=np.std(actual_output_lens or 0) * 1,
median_output_len=np.median(actual_output_lens or 0) * 1,
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p))
for p in selected_percentiles],
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
)
return metrics, actual_output_lens
@@ -351,20 +334,22 @@ async def benchmark(
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
[random.choice(lora_modules) \
for _ in range(len(input_requests))])
lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
output_len=test_output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body)
test_prompt = None
test_output_len = None
profile_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
output_len=test_output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
@@ -384,19 +369,16 @@ async def benchmark(
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
print(f"开始时间:{datetime.now()}")
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
@@ -409,25 +391,26 @@ async def benchmark(
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id,
model_name=req_model_name,
prompt=prompt,
prompt_len=0,
history_QA=history_QA,
hyper_parameters=hyper_parameters,
api_url=api_url,
output_len=output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
pbar=pbar)))
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
prompt_len=0,
history_QA=history_QA,
hyper_parameters=hyper_parameters,
api_url=api_url,
output_len=output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body,
)
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
print(f"完成时间:{datetime.now()}")
if profile:
print("Stopping profiler...")
test_output_len = None
test_output_len = None
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@@ -454,22 +437,16 @@ async def benchmark(
)
print("Benchmark complete!!!")
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
@@ -477,8 +454,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@@ -491,7 +467,6 @@ async def benchmark(
"reasoning_contents": [output.reasoning_content for output in outputs],
"errors": [output.error for output in outputs],
}
quick_result = copy.deepcopy(result)
def process_one_metric(
# E.g., "ttft"
@@ -505,24 +480,25 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length(
@@ -537,31 +513,31 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}")))
result[f"mean_{metric_attribute_name}"] = getattr(
metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(
metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(
metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}"),
)
)
result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -581,6 +557,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
"""
快速评估
"""
def process_quick_metric(
metric_attribute_name: str,
metric_name: str,
@@ -588,7 +565,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
):
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value))
quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value
@@ -600,17 +577,17 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
):
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value))
quick_result[f"mean_{metric_attribute_name}"] = mean_value
print("\n\n\n")
print("{s:{c}^{n}}".format(s=' Benchmark Quick Summary ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Benchmark Quick Summary ", n=50, c="="))
process_quick_length("s_decode", "Decode", "解码速度(tok/s)")
process_quick_metric("ttft", "TTFT", "Time to First Token")
process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_quick_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_quick_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_quick_metric("itl", "ITL", "Inter-token Latency")
process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_quick_metric("e2el", "E2EL", "End-to-end Latency")
@@ -633,12 +610,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
f"{str(VALID_NAMES)}. ")
f"{VALID_NAMES!s}. "
)
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative.")
"non-negative."
)
return goodput_config_dict
@@ -652,37 +631,43 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" "
'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err
"number in milliseconds."
) from err
return goodput_config_dict
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any],
file_name: str) -> None:
def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
"""Save the benchmarking results to PyTorch Benchmark Format JSON file"""
metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
"median_ttft_ms",
"mean_ttft_ms",
"std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
]
# These raw data might be useful, but they are rather big. They can be added
# later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={k: [results[k]]
for k in metrics},
extra_info={
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
metrics={k: [results[k]] for k in metrics},
extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
)
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def check_health(api_base_url: str) -> bool:
health_url = api_base_url.rstrip("/") + "/health"
try:
@@ -697,6 +682,7 @@ def check_health(api_base_url: str) -> bool:
print(f"[HEALTH] Failed to connect to {health_url}: {e}")
return False
def main(args: argparse.Namespace):
"""Main entry point"""
print(args)
@@ -707,7 +693,6 @@ def main(args: argparse.Namespace):
model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
@@ -717,23 +702,17 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}"
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"EB":
lambda: EBDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
"EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"EBChat":
lambda: EBChatDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
"EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
}
@@ -751,15 +730,14 @@ def main(args: argparse.Namespace):
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature
}.items() if v is not None
"temperature": args.temperature,
}.items()
if v is not None
}
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
"Sampling parameters are only supported by openai-compatible "
"backends.")
raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -790,15 +768,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
)
)
# Save config and results to json
if args.save_result:
@@ -819,22 +796,23 @@ def main(args: argparse.Namespace):
kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format."
)
raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
if not args.save_detailed:
# Remove fields with too many data points
for field in [
"input_lens", "output_lens", "ttfts", "itls",
"generated_texts", "errors"
"input_lens",
"output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]:
if field in result_json:
del result_json[field]
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
@@ -843,21 +821,19 @@ def main(args: argparse.Namespace):
# Save to file
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None else "")
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile:
with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
@@ -883,18 +859,29 @@ if __name__ == "__main__":
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"],
choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
type=str,
default=None,
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.")
parser.add_argument("--hyperparameter-path",
type=str,
default=None,
help="Path to the hyperparameter. ")
parser.add_argument(
"--dataset-path",
type=str,
default=None,
help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument(
"--hyperparameter-path",
type=str,
default=None,
help="Path to the hyperparameter. ",
)
parser.add_argument(
"--max-concurrency",
type=int,
@@ -906,7 +893,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
@@ -917,7 +905,7 @@ if __name__ == "__main__":
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@@ -930,11 +918,13 @@ if __name__ == "__main__":
"--logprobs",
type=int,
default=None,
help=("Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
help=(
"Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"
),
)
parser.add_argument(
"--request-rate",
@@ -971,8 +961,7 @@ if __name__ == "__main__":
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--save-result",
@@ -1013,35 +1002,38 @@ if __name__ == "__main__":
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
"Default value is \"ttft,tpot,itl\".")
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'Default value is "ttft,tpot,itl".',
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\". "
"Use \"--percentile-metrics\" to select metrics.",
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
'Default value is "99". '
'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
'"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -1069,8 +1061,8 @@ if __name__ == "__main__":
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
)
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
@@ -1098,29 +1090,24 @@ if __name__ == "__main__":
"--random-prefix-len",
type=int,
default=0,
help=("Number of fixed prefix tokens before the random context "
"in a request. "
"The total input length is the sum of `random-prefix-len` and "
"a random "
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
help=(
"Number of fixed prefix tokens before the random context "
"in a request. "
"The total input length is the sum of `random-prefix-len` and "
"a random "
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
)
sampling_group = parser.add_argument_group("sampling parameters")
@@ -1128,52 +1115,58 @@ if __name__ == "__main__":
"--top-p",
type=float,
default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--top-k",
type=int,
default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--min-p",
type=float,
default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--temperature",
type=float,
default=None,
help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).")
"decoding (i.e. temperature==0.0).",
)
parser.add_argument(
'--tokenizer-mode',
"--tokenizer-mode",
type=str,
default="auto",
choices=['auto', 'slow', 'mistral', 'custom'],
choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
"always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.')
'"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")
parser.add_argument(
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules",
nargs='+',
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.")
parser.add_argument(
"--lora-modules",
nargs="+",
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.",
)
args = parser.parse_args()

View File

@@ -7,4 +7,4 @@ tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint4
reasoning_parser: ernie-45-vl
reasoning_parser: ernie-45-vl

View File

@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3
max_long_partial_prefills: 3

View File

@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"
pd_comm_port: "2333"

View File

@@ -10,4 +10,4 @@ engine_worker_queue_port: 6677
num_gpu_blocks_override: 1024
cache_transfer_protocol: "rdma"
rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
pd_comm_port: "2334"
pd_comm_port: "2334"

View File

@@ -10,4 +10,4 @@ splitwise_role: decode
engine_worker_queue_port: 6678
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334"
pd_comm_port: "2334"

View File

@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"
pd_comm_port: "2333"

View File

@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3
max_long_partial_prefills: 3

View File

@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"
pd_comm_port: "2333"

View File

@@ -3,4 +3,4 @@ max_num_seqs: 75
gpu_memory_utilization: 0.85
kv_cache_ratio: 0.75
quantization: wint4
tensor_parallel_size: 4
tensor_parallel_size: 4

View File

@@ -3,4 +3,4 @@ max_num_seqs: 25
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.75
quantization: wint8
tensor_parallel_size: 4
tensor_parallel_size: 4

View File

@@ -1,3 +1,3 @@
metadata:
min_tokens: 32
max_tokens: 33
max_tokens: 33

View File

@@ -5,4 +5,4 @@ metadata:
max_tokens: 12288
repetition_penalty: 1.05
frequency_penalty: 0
presence_penalty: 0
presence_penalty: 0

View File

@@ -5,4 +5,4 @@ metadata:
max_tokens: 12288
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 1.5
presence_penalty: 1.5

View File

@@ -8,4 +8,4 @@ frequency_penalty: 0
presence_penalty: 0
skip_special_tokens: false
chat_template_kwargs:
enable_thinking: true
enable_thinking: true

View File

@@ -3,4 +3,4 @@ max_num_seqs: 64
gpu_memory_utilization: 0.9
tensor_parallel_size: 8
quantization: wint8
reasoning_parser: ernie-x1
reasoning_parser: ernie-x1

View File

@@ -40,4 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out);

View File

@@ -216,7 +216,7 @@ __global__ void append_dequant_cache_kv_c8(
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
tid % 8 * num_elems_per_128b<CacheT>();
@@ -330,7 +330,7 @@ __global__ void append_dequant_cache_kv_c8(
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T
#ifdef C8_DEBUG
if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
@@ -373,14 +373,14 @@ void AppendDequantCache(
paddle::Tensor *k_out,
paddle::Tensor *v_out,
const cudaStream_t& stream
) {
) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0];
dim3 grids(block_num, 1, kv_num_heads);
dim3 blocks(32, NUM_WARPS);
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;

View File

@@ -41,7 +41,7 @@ __global__ void append_clear_cache_int8_block(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
@@ -115,7 +115,7 @@ __global__ void append_clear_cache_int4_block(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
@@ -484,7 +484,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
@@ -716,7 +716,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
@@ -1097,7 +1097,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
@@ -1403,7 +1403,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
@@ -1792,4 +1792,4 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
}
}
}
}

View File

@@ -582,4 +582,4 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out);

View File

@@ -39,4 +39,4 @@ void SpeculateWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out);
paddle::Tensor* value_cache_out);

View File

@@ -30,4 +30,4 @@ inline int getSMVersion()
return sm_major * 10 + sm_minor;
}
}
}

View File

@@ -136,4 +136,4 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
ElementAccumulator, DefaultScaleMode>;
};
} // namespace cutlass_extensions
} // namespace cutlass_extensions

View File

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

View File

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -54,7 +54,7 @@
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides API to promote (add) or scale (multiply_add) the results
/// from the tensor core accumulators to the main accumulators when the number
/// from the tensor core accumulators to the main accumulators when the number
/// of MMAs reaches the max number of MMA interval specified by user, after that
/// the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////
@@ -64,7 +64,7 @@ namespace cutlass::gemm::collective {
template <
class EngineAccum,
class LayoutAccum>
struct GmmaFP8AccumulationWithScale {
struct GmmaFP8AccumulationWithScale {
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
using ElementAccumulator = typename EngineAccum::value_type;
@@ -78,7 +78,7 @@ private:
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
uint32_t mma_count_; // current executed MMAs
uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
// promote or `add` the partial accumulators to main accumulator (FADD).
CUTLASS_DEVICE
@@ -116,11 +116,11 @@ public:
TensorAccum &accum,
uint32_t accum_promotion_interval,
uint32_t mma_count_per_mainloop_iteration)
: accum_(accum),
: accum_(accum),
accum_promotion_interval_(accum_promotion_interval),
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
mma_count_(0),
reset_accum_flag_(0)
mma_count_(0),
reset_accum_flag_(0)
{
accum_temp_ = cute::make_fragment_like(accum);
}
@@ -129,14 +129,14 @@ public:
// Methods (Common)
//
CUTLASS_DEVICE
CUTLASS_DEVICE
TensorAccum& operator()() {
return accum_temp_;
}
/// prepare the MMA accumulators when initialization or zeroing is required.
CUTLASS_DEVICE
bool prepare_if_needed() {
bool prepare_if_needed() {
return reset_accum_flag_;
}

View File

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -137,7 +137,7 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params;
// Two threads per CTA are producers (1 for operand tile and 32 for scales)
static constexpr int NumProducerThreadEvents = 33;
static constexpr int NumProducerThreadEvents = 33;
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
@@ -161,11 +161,11 @@ struct CollectiveMma<
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
// Block scaling gmem-to-smem copy atom
// Block scaling gmem-to-smem copy atom
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
// Block scaling smem layout
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
@@ -202,7 +202,7 @@ struct CollectiveMma<
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_B;
};
@@ -228,7 +228,7 @@ struct CollectiveMma<
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
// Block scaling factors for A and B
ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_B;
};
@@ -285,7 +285,7 @@ struct CollectiveMma<
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
@@ -346,7 +346,7 @@ struct CollectiveMma<
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
@@ -406,26 +406,26 @@ struct CollectiveMma<
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
Tensor gScaleA = local_tile(
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
Tensor gScaleA = local_tile(
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
Tensor cScaleA = local_tile(
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
Tensor cScaleA = local_tile(
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
make_coord(m_coord,_,l_coord));
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
@@ -455,7 +455,7 @@ struct CollectiveMma<
}
}
// Allocate predicate tensors for a_scales (since we can't guarantee that
// Allocate predicate tensors for a_scales (since we can't guarantee that
// all scales are valid, since we could have a partial tiles along M)
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
#pragma unroll
@@ -536,7 +536,7 @@ struct CollectiveMma<
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
// Block scaling
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
Layout<
@@ -548,17 +548,17 @@ struct CollectiveMma<
//
// Define C accumulators and A/B partitioning
//
// Layout of warp group to thread mapping
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
stride<0>(typename TiledMma::BLayout{}) == 0 and
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
@@ -590,7 +590,7 @@ struct CollectiveMma<
// We release buffers to producer warps(dma load) with some mmas in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Per block scale values for operand A and B
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
@@ -618,7 +618,7 @@ struct CollectiveMma<
}
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers.
scale_b = sScaleB[read_stage];
CUTLASS_PRAGMA_UNROLL
@@ -668,7 +668,7 @@ struct CollectiveMma<
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
scale_b = sScaleB[read_stage];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
@@ -712,7 +712,7 @@ struct CollectiveMma<
++smem_pipe_read;
++smem_pipe_release;
}
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
warpgroup_fence_operand(accumulation());

View File

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -50,4 +50,4 @@ struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
//////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm
} // namespace cutlass::gemm

View File

@@ -90,4 +90,4 @@ struct GemmMoeProblemVisitor
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -90,7 +90,7 @@ template <
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaMultistage :
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class

View File

@@ -57,7 +57,7 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){
hasbias,
ElementD,
void>;
constexpr int ScaleMsPerTile = size<0>(TileShape{});
constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;
@@ -161,7 +161,7 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){
arguments.scheduler.decomposition_mode = DecompositionMode::StreamK;
arguments.scheduler.reduction_mode = ReductionMode::Nondeterministic;
}
Gemm gemm_op;

View File

@@ -170,4 +170,4 @@ bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
return false;
}
return true;
}
}

View File

@@ -148,4 +148,4 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
return false;
}
return true;
}
}

View File

@@ -54,7 +54,7 @@ public:
virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int k) const = 0;
protected:
static constexpr int SPLIT_K_LIMIT = 7;
static constexpr int MIN_M_TILE = 16;

View File

@@ -93,8 +93,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat
PD_BUILD_STATIC_OP(extract_text_token_output)
.Inputs({"max_seq_len",
"max_seq_len_index",
"mm_token_num_len",
"max_seq_len_index",
"mm_token_num_len",
"seq_lens_this_time",
"cu_seqlens_q",
"score_text"})

View File

@@ -105,7 +105,7 @@ __global__ void cudaCoreGemm(InputType const* __restrict__ act,
}
}
}
__syncthreads();
for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
int32_t mid = ii / TILE_N, nid = ii % TILE_N;
@@ -188,4 +188,4 @@ bool cuda_core_gemm_launcher(GemmParams const& params) {
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);

View File

@@ -61,7 +61,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
st_idx += cur_st_len;
}
}
while (idx < seq_lens_origin) {
idx = idx + split_fuse_text_size;
if (idx >= seq_lens_origin) {
@@ -116,7 +116,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
while (ib < img_total && cur_img_len < chunk_image_token_number) {
int token_times = 4;
cur_img_len += (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times;
ib ++;
ib ++;
chunk_image_number ++;
}
image_chunk_selections_vector.emplace_back(chunk_image_number);

View File

@@ -88,7 +88,7 @@ void sent_key_value_by_remote_ptr(
#ifdef DEBUG_IPC_SENT
std::cout<<"remote_key_tensor_sent_ptr:"<<(int64_t)remote_key_tensor_sent_ptr
<<" local_key_tensor_sent_ptr:"<<(int64_t)local_key_tensor_sent_ptr
<<" local_device_id:" << local_device_id
<<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte
@@ -107,25 +107,25 @@ void sent_key_value_by_remote_ptr(
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte,
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id,
block_size_byte);
#endif
cudaError_t err = cudaGetLastError();
if ( err != cudaSuccess )
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize();
@@ -140,7 +140,7 @@ void sent_key_value_by_remote_ptr(
#ifdef DEBUG_IPC_SENT
std::cout<<"remote_value_tensor_sent_ptr:"<<(int64_t)remote_value_tensor_sent_ptr
<<" local_value_tensor_sent_ptr:"<<(int64_t)local_value_tensor_sent_ptr
<<" local_device_id:" << local_device_id
<<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte
@@ -159,26 +159,26 @@ void sent_key_value_by_remote_ptr(
#endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte,
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte,
stream);
#endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id,
block_size_byte);
cudaDeviceSynchronize();
#endif
err = cudaGetLastError();
if ( err != cudaSuccess )
{
printf("CUDA Error: %s\n", cudaGetErrorString(err));
printf("CUDA Error: %s\n", cudaGetErrorString(err));
}
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
@@ -316,11 +316,11 @@ void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
cudaStreamSynchronize(cuda_stream);
}
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
.Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"})
.Attrs({ "block_num: int",
"local_device_id: int",
.Attrs({ "block_num: int",
"local_device_id: int",
"remote_device_id: int",
"cuda_stream_raw: int64_t"})
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
@@ -332,4 +332,4 @@ PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
.Attrs({"cuda_stream_raw: int64_t"})
.Outputs({"local_key_tensor_out", "local_value_tensor_out"})
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));

View File

@@ -57,5 +57,3 @@ paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
num_experts);
return token_nums_per_expert;
}

View File

@@ -737,7 +737,7 @@ void MoeFastHardamardWrapper(const T *x_data,
bool FLAGS_hardamard_use_diagonal_block_matrix = true;
static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
constexpr int kThreads = 128;
if (FLAGS_hardamard_use_diagonal_block_matrix) {

View File

@@ -124,4 +124,4 @@ class CubKeyValueSorter {
int num_bits_;
};
} // namespace phi
} // namespace phi

View File

@@ -360,10 +360,10 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
@@ -374,10 +374,10 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
@@ -518,12 +518,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z;
}
__syncthreads();
T val = T(threadDataExp * normalizing_factor);
// top_k
// top_k
using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP;
@@ -541,7 +541,7 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
if (threadIdx.x < num_experts) {
cub_kvp inp_kvp;
int expert = threadIdx.x;
int expert = threadIdx.x;
inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val;
@@ -1065,7 +1065,7 @@ __global__ void initialize_moe_routing_kernel(
const T* unpermuted_input,
OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
const int *expert_idx_per_token,
const int *expert_idx_per_token,
const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows,
@@ -1088,7 +1088,7 @@ __global__ void initialize_moe_routing_kernel(
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row;
}
if (expanded_dest_row < active_rows) {
const int expert_idx = expert_idx_per_token[expanded_dest_row];
@@ -1130,7 +1130,7 @@ static void run(
const T* unpermuted_input,
OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row,
const int *expert_idx_per_token,
const int *expert_idx_per_token,
const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows,

View File

@@ -17,7 +17,7 @@
// topk warps
template<typename T, int VecSize>
__global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int* permute_indices_per_token, const T* x, const int64_t* topk_idx, const int token_num, const int topk, const int num_vecs, const int hidden, const int max_tokens_per_expert) {
AlignedVector<T, VecSize> in_vec;
const int bid = blockIdx.x;
@@ -32,7 +32,7 @@ __global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int
}
tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);
for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) {
Load<T, VecSize>(x + token_idx * hidden + hidden_vec_id * VecSize, &in_vec);
Store<T, VecSize>(in_vec, out + tgt_expert_id * max_tokens_per_expert * hidden + tgt_expert_token * hidden + hidden_vec_id * VecSize);
@@ -81,7 +81,7 @@ std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
permute_indices_per_token.data<int32_t>(),
reinterpret_cast<const DataType_ *>(x.data<data_t>()),
topk_idx.data<int64_t>(),
token_num, topk, num_vecs,
token_num, topk, num_vecs,
hidden, max_tokens_per_expert
);
@@ -112,4 +112,4 @@ PD_BUILD_STATIC_OP(moe_deepgemm_permute)
.Inputs({"x", "topk_idx"})
.Outputs({"permute_output", "token_nums_per_expert", "permute_indices_per_token"})
.Attrs({"num_experts: int", "max_tokens_per_expert: int"})
.SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));
.SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));

View File

@@ -232,12 +232,12 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
/**
* @brief Mixture of Experts (MoE) Expert Dispatch Operator
*
*
* This operator performs the following key functions:
* 1. Computes top-k experts for each input token based on gating scores
* 2. Permutes input tokens according to their selected experts for efficient expert processing
* 3. Computes prefix sums of tokens per expert for group_gemm optimization
*
*
* Inputs:
* - input: The input tensor to be routed to experts
* Shape: [total_tokens, hidden_size]
@@ -246,7 +246,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
* Shape: [total_tokens, expert_num]
* dtype: must be float32
* - gating_correction_bias: Optional bias term for gating correction (expert_num)
*
*
* Outputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [moe_topk * total_tokens, hidden_size]
@@ -263,7 +263,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
* - top_k_indices: Indices of selected top-k experts for each token
* Shape: [total_tokens, moe_topk]
* dtype: int32
*
*
* Attributes:
* - moe_topk: Number of experts to select for each token (k value in top-k routing)
* - group_moe: Whether to perform group softmax within the operator
@@ -272,7 +272,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
* - topk_only_mode: Operation mode selector
* (true: only performs topk selection without softmax,
* false: performs full softmax+topk computation)
*
*
* Note:
* - The operator requires 2D input format [total_tokens, hidden_size]
* - For optimal performance, expert_num should be a power of 2 when possible
@@ -283,7 +283,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input", "tokens_expert_prefix_sum",
"permute_indices_per_token", "topk_weight", "topk_idx",
"permute_indices_per_token", "topk_weight", "topk_idx",
"expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))

View File

@@ -263,4 +263,4 @@ PD_BUILD_OP(moe_redundant_topk_select)
.SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
.SetKernelFn(PD_KERNEL(MoERedundantTopKSelectKernel))
.SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectKernelInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype));

View File

@@ -106,4 +106,4 @@ template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4,
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
}
}

View File

@@ -36,4 +36,4 @@ struct msgdata {
struct msgdatakv {
long mtype;
int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair
};
};

View File

@@ -14,9 +14,10 @@
"""read_ids"""
import os
import numpy as np
import struct
import numpy as np
def deserialize_from_file(fp):
"""deserialize from file"""

View File

@@ -13,9 +13,10 @@
# limitations under the License.
"""read temp_ids from file"""
import os
import numpy as np
import struct
import numpy as np
def deserialize_from_file(fp):
"""

View File

@@ -15,7 +15,7 @@
#include "remote_cache_kv_ipc.h"
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data RemoteCacheKvIpc::kv_complete_signal_meta_data;
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query;
void* RemoteCacheKvIpc::kv_complete_signal_identity_ptr = nullptr;
bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
@@ -118,4 +118,3 @@ void CUDART_CB RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_que
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
// std::printf("#### save_cache_kv_complete_signal_layerwise_per_query);
}

View File

@@ -71,7 +71,7 @@ struct RemoteCacheKvIpc {
}
}
msg_sed.mtext[0] = encoder_count;
if (!inited) {
// just init once
const int msg_id = 1024 + rank;
@@ -90,7 +90,7 @@ struct RemoteCacheKvIpc {
assert(layer_id_ <= num_layers_);
}
};
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data;
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query;
static void* kv_complete_signal_identity_ptr;

View File

@@ -125,7 +125,7 @@ void group_wise_scale(ScaleT* scale,
}
}
std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &input,
std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &input,
int groupsize,
std::string scale_dtype) {
auto input_cpu = input.copy_to(paddle::CPUPlace(), false);
@@ -139,47 +139,47 @@ std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &in
if (groupsize > 0) {
scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace());
group_wise_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize);
group_wise_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::bfloat16>(),
k,
group_wise_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::bfloat16>(),
k,
n,
groupsize);
} else {
scale = paddle::full({shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace());
per_channel_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f);
per_channel_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::bfloat16>(),
k,
per_channel_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::bfloat16>(),
k,
n);
}
} else if (scale_dtype == "float16") {
if (groupsize > 0) {
scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
group_wise_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize);
group_wise_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::float16>(),
k,
group_wise_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::float16>(),
k,
n,
groupsize);
} else {
scale = paddle::full({shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
scale = paddle::full({shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
per_channel_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f);
per_channel_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::float16>(),
k,
per_channel_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(),
scale.data<phi::dtype::float16>(),
k,
n);
}
}
auto out = paddle::full({shape[1] / 2, shape[0]}, 0, paddle::DataType::INT8, paddle::CPUPlace());
preprocess_weights_for_mixed_gemm(
out.data<int8_t>(),
packed_int4.data<int8_t>(),
{k, n},
out.data<int8_t>(),
packed_int4.data<int8_t>(),
{k, n},
kernels::cutlass_kernels::QuantType::W4_AFP8,
false);
return {out, scale};

View File

@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -27,7 +27,7 @@
std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
const std::string shm_name,
const std::vector<int>& shape) {
const std::vector<int>& shape) {
volatile shmStruct *shm = NULL;
sharedMemoryInfo info;
if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
@@ -62,4 +62,4 @@ PD_BUILD_STATIC_OP(share_external_data)
.Inputs({"input"})
.Outputs({"output"})
.Attrs({"shm_name: std::string", "shape: std::vector<int>"})
.SetKernelFn(PD_KERNEL(ShareExternalData));
.SetKernelFn(PD_KERNEL(ShareExternalData));

View File

@@ -19,7 +19,7 @@
// #define DEBUG_EAGLE_KERNEL
__global__ void ComputeOrderKernel(
const int* seq_lens_this_time,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
@@ -47,7 +47,7 @@ __global__ void ComputeOrderKernel(
printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
position_map[in_offset++] = out_offset++;
}
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
@@ -69,13 +69,13 @@ __global__ void ComputeOrderKernel(
in_offset += cur_base_model_seq_lens_this_time;
} else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num > actual_draft_token_num \n", i);
printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif
position_map[in_offset + accept_num - 2] = out_offset++;
position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time;
}
}
}
}
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
@@ -208,7 +208,7 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
}
case paddle::DataType::BFLOAT16: {
return DispatchDtype<paddle::DataType::BFLOAT16>(
input,
input,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,

View File

@@ -72,7 +72,7 @@ __global__ void computeOrderKernel(
output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) {
for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", src_map[i]);
}
printf("\n");
@@ -187,4 +187,4 @@ PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
"seq_lens_this_time",
"step_idx"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));

View File

@@ -26,7 +26,7 @@ __global__ void RebuildAppendPaddingKernel(
const int seq_len,
const int dim_embed,
const size_t elem_nums) {
using LoadT = AlignedVector<T, VecSize>;
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
@@ -42,7 +42,7 @@ __global__ void RebuildAppendPaddingKernel(
const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
const int bias_idx = i % dim_embed;
Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
Store<T, VecSize>(src_vec, &out[i]);
}
@@ -78,14 +78,14 @@ std::vector<paddle::Tensor> DispatchDtype(
GetNumBlocks(pack_num, &grid_size);
RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
reinterpret_cast<DataType_*>(out.data<data_t>()),
reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
cum_offsets.data<int32_t>(),
seq_len_encoder.data<int32_t>(),
seq_len_decoder.data<int32_t>(),
output_padding_offset.data<int32_t>(),
max_seq_len,
dim_embed,
reinterpret_cast<DataType_*>(out.data<data_t>()),
reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
cum_offsets.data<int32_t>(),
seq_len_encoder.data<int32_t>(),
seq_len_decoder.data<int32_t>(),
output_padding_offset.data<int32_t>(),
max_seq_len,
dim_embed,
elem_nums);
return {out};
}
@@ -99,7 +99,7 @@ std::vector<paddle::Tensor> RebuildAppendPadding(
const paddle::Tensor& output_padding_offset,
const int max_seq_len) {
switch (full_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>(
@@ -137,7 +137,7 @@ std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
.Inputs({"full_hidden_states",
.Inputs({"full_hidden_states",
"cum_offsets",
"seq_len_encoder",
"seq_len_decoder",
@@ -146,4 +146,4 @@ PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(RebuildAppendPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));

View File

@@ -93,7 +93,7 @@ __global__ void speculate_free_and_reschedule(bool *stop_flags,
used_list_len[tid] = 0;
}
} else if (seq_lens_this_time[tid] != 0 && max_possible_block_idx < block_num_per_seq &&
block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
1) /
block_size] == -1) {
// 统计需要分配block的位置和总数
@@ -347,7 +347,7 @@ PD_BUILD_STATIC_OP(speculate_step_reschedule)
"next_tokens",
"first_token_ids",
"accept_num"})
.Attrs({"block_size: int",
.Attrs({"block_size: int",
"encoder_decoder_block_num: int",
"max_draft_tokens: int"})
.Outputs({"stop_flags_out",

View File

@@ -60,7 +60,7 @@ __global__ void recover_block_system_cache(int *recover_block_list, // [bsz]
const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len);
ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP
printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, ori_free_list_len: %d\n",
printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, ori_free_list_len: %d\n",
recover_id, ori_seq_len_encoder, step_idx_now, seq_len, ori_free_list_len_tid0, ori_free_list_len);
#endif
}
@@ -95,7 +95,7 @@ void StepSystemCache(const paddle::Tensor& stop_flags,
const paddle::Tensor& recover_lens,
const paddle::Tensor& need_block_list,
const paddle::Tensor& need_block_len,
const paddle::Tensor& used_list_len,
const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len,
const paddle::Tensor& input_ids,
@@ -178,7 +178,7 @@ void StepSystemCache(const paddle::Tensor& stop_flags,
}
PD_BUILD_STATIC_OP(step_system_cache)
.Inputs({"stop_flags",
.Inputs({"stop_flags",
"seq_lens_this_time",
"ori_seq_lens_encoder",
"ori_seq_lens_decoder",

View File

@@ -68,26 +68,26 @@ void SwapCache(const paddle::Tensor& cache_gpu, // gpu
switch (cache_gpu.dtype()) {
case paddle::DataType::BFLOAT16:
return SwapCacheImpl<paddle::DataType::BFLOAT16>(
cache_gpu,
cache_cpu_ptr,
cache_gpu,
cache_cpu_ptr,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::FLOAT16:
return SwapCacheImpl<paddle::DataType::FLOAT16>(
cache_gpu,
cache_cpu_ptr,
cache_gpu,
cache_cpu_ptr,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
mode);
case paddle::DataType::UINT8:
return SwapCacheImpl<paddle::DataType::UINT8>(
cache_gpu,
cache_cpu_ptr,
cache_gpu,
cache_cpu_ptr,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
mode);
default:

View File

@@ -47,7 +47,7 @@ inline cudaError_t GetGridSize(int64_t n, int block_size, int num_waves, int* nu
template<typename T, int VecSize>
__global__ void text_image_scatter_kernel(
T* input_ptr,
T* input_ptr,
T* text_gather_ptr,
T* image_gather_ptr,
int32_t* token_type_ids,
@@ -72,8 +72,8 @@ __global__ void text_image_scatter_kernel(
int32_t token_type_ids_num = token_type_ids[token_idx];
int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll
for(int vi = 0; vi < VecSize; ++vi) {
text_imgaes_vec[vi] = input_ptr_vec[vi];
@@ -92,7 +92,7 @@ __global__ void text_image_scatter_kernel(
template<typename T, int VecSize>
__global__ void text_image_gather_kernel(
T* output_ptr,
T* output_ptr,
T* text_gather_ptr,
T* image_gather_ptr,
int32_t* token_type_ids,
@@ -131,8 +131,8 @@ __global__ void text_image_gather_kernel(
}
int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset);
Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset);
}
}
@@ -159,7 +159,7 @@ void LaunchTextImageGatherScatter(
const int64_t tot_element_num = token_num * hidden_size;
int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize;
const int block_size = 128;
int grid_index = (token_num + block_size - 1) / block_size;
constexpr int32_t kNumWaves = 16;
@@ -170,8 +170,8 @@ void LaunchTextImageGatherScatter(
if (is_scatter) {
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()),
reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
@@ -181,8 +181,8 @@ void LaunchTextImageGatherScatter(
} else {
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()),
reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
@@ -216,8 +216,8 @@ void TextImageGatherScatter(
PD_BUILD_STATIC_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
"image_input",
"text_input",
"image_input",
"token_type_ids",
"text_index",
"image_index"})
@@ -229,5 +229,5 @@ PD_BUILD_STATIC_OP(text_image_gather_scatter)
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));

View File

@@ -16,7 +16,7 @@
template <int VecSize>
__global__ void text_image_index_out_kernel(
int32_t* token_type_ids,
int32_t* token_type_ids,
int32_t* text_index,
int32_t* image_index,
const int64_t token_num
@@ -25,7 +25,7 @@ __global__ void text_image_index_out_kernel(
if (global_thread_idx >= 1) return;
int text_count = 0;
int images_count = 0;
for (int i = 0; i < token_num; ++i) {
// printf(" %d %d %d %d \n", text_index[i], text_count, images_count, i);
if (token_type_ids[i] == 0) {
@@ -60,5 +60,5 @@ PD_BUILD_STATIC_OP(text_image_index_out)
.Outputs({"text_index_out",
"image_index_out"})
.SetInplaceMap({{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageIndexOut));

View File

@@ -810,4 +810,4 @@ PD_BUILD_STATIC_OP(tune_cublaslt_gemm)
"is_test: bool",
"is_read_from_file: bool",
"path: std::string"})
.SetKernelFn(PD_KERNEL(TuneCublasltGemm));
.SetKernelFn(PD_KERNEL(TuneCublasltGemm));

View File

@@ -33,7 +33,7 @@ __global__ void update_inputs_beam_kernel(
if (block_idx == 0) {
seq_lens_this_time[thread_idx] = seq_lens_this_time[bsz_index];
seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
}
}
if (block_idx < seq_len) {
input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
}
@@ -74,8 +74,8 @@ void UpdateInputesBeam(
PD_BUILD_STATIC_OP(update_inputs_beam)
.Inputs({"beam_width",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_this_time",
"seq_lens_encoder",
"input_ids",
"logits"})
.Outputs({"seq_lens_this_time_out",
@@ -86,4 +86,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"input_ids", "input_ids_out"},
{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));
.SetKernelFn(PD_KERNEL(UpdateInputesBeam));

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FastDeploy custom ops """
"""setup for FastDeploy custom ops"""
import importlib
import json
import os
@@ -41,8 +41,7 @@ ROOT_DIR = Path(__file__).parent.parent
# cannot import envs directly because it depends on fastdeploy,
# which is not installed yet
envs = load_module_from_path('envs',
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
@@ -143,8 +142,7 @@ def get_nvcc_version():
"""
Get cuda version of nvcc.
"""
nvcc_output = subprocess.check_output(["nvcc", "--version"],
universal_newlines=True)
nvcc_output = subprocess.check_output(["nvcc", "--version"], universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = float(output[release_idx].split(",")[0])
@@ -160,13 +158,19 @@ def get_gencode_flags(archs):
for cc_val in cc_s:
if cc_val == 90:
arch_code = "90a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
# Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
# https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
# "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
arch_code = "100a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
else:
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
return flags
@@ -194,7 +198,7 @@ if paddle.is_compiled_with_rocm():
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!")
sources=[
sources = [
"gpu_ops/set_value_by_flags.cu",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/stop_generation.cu",
@@ -302,8 +306,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
if not os.path.exists(cutlass_dir):
os.makedirs(cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git",
cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
if not os.listdir(cutlass_dir):
raise ValueError("Git clone cutlass failed!")
@@ -312,8 +315,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
if not os.path.exists(deep_gemm_dir):
os.makedirs(deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git",
deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
if not os.listdir(deep_gemm_dir):
raise ValueError("Git clone DeepGEMM failed!")
cur_path = os.path.dirname(os.path.abspath(__file__))
@@ -347,15 +349,13 @@ elif paddle.is_compiled_with_cuda():
try:
shutil.copytree(src_dir, dst_dir)
except Exception as e:
raise RuntimeError(
f"Failed to copy from {src_dir} to {dst_dir}: {e}")
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git",
json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!")
@@ -372,7 +372,7 @@ elif paddle.is_compiled_with_cuda():
"-Ithird_party/nlohmann_json/include",
]
nvcc_version = get_nvcc_version()
print(f'nvcc_version = {nvcc_version}')
print(f"nvcc_version = {nvcc_version}")
if nvcc_version >= 12.0:
sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
cc = max(get_sm_version(archs))
@@ -414,31 +414,24 @@ elif paddle.is_compiled_with_cuda():
# Running generate fp8 gemm codes.
# Common for SM89, SM90, SM100 (Blackwell)
nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += [
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
]
nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
if cc >= 90: # Hopper and newer
if cc >= 90: # Hopper and newer
# SM90 (Hopper) specific auto-generation and flags
if cc == 90: # Only for SM90
if cc == 90: # Only for SM90
nvcc_compile_args += [
# The gencode for 90a is added in get_gencode_flags now
# "-gencode",
# "arch=compute_90a,code=compute_90a",
"-O3",
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
]
print("SM90: Running SM90-specific FP8 kernel auto-generation.")
os.system(
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")
nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1",
@@ -450,14 +443,14 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
]
elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
print("SM100 (Blackwell): Applying SM100 configurations.")
nvcc_compile_args += [
# The gencode for 100a is added in get_gencode_flags
# "-gencode",
# "arch=compute_100a,code=compute_100a",
"-O3", # Common optimization flag
"-DNDEBUG", # Common debug flag
"-O3", # Common optimization flag
"-DNDEBUG", # Common debug flag
# Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
]
# Placeholder for SM100-specific kernel auto-generation scripts
@@ -469,18 +462,16 @@ elif paddle.is_compiled_with_cuda():
# Add SM100 specific sources if any, e.g., for new hardware intrinsics
# sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
pass # No SM100 specific sources identified yet beyond what CUTLASS handles
else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
pass # No SM100 specific sources identified yet beyond what CUTLASS handles
else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else: # For cc == 89 (Ada)
else: # For cc == 89 (Ada)
print("SM89: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
# Common FP8 sources for SM89+
sources += [
@@ -493,7 +484,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
"gpu_ops/fused_hadamard_quant_fp8.cu"
"gpu_ops/fused_hadamard_quant_fp8.cu",
]
sources += find_end_files(fp8_auto_gen_directory, ".cu")

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FASTDEPLOY base ops """
"""setup for FASTDEPLOY base ops"""
from paddle.utils.cpp_extension import CppExtension, setup
@@ -27,7 +27,8 @@ setup(
"cpu_ops/rebuild_padding.cc",
],
extra_compile_args=[
"-DPy_LIMITED_API=0x03090000", "-DPADDLE_ON_INFERENCE"
"-DPy_LIMITED_API=0x03090000",
"-DPADDLE_ON_INFERENCE",
],
),
)

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FASTDEPLOY custom cpu ops """
"""setup for FASTDEPLOY custom cpu ops"""
import os
import subprocess
import tarfile
@@ -26,8 +26,7 @@ ROOT_DIR = Path(__file__).parent.parent
# which is not installed yet
from .setup_ops import load_module_from_path
envs = load_module_from_path('envs',
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
BUILDING_ARCS = []
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

View File

@@ -48,17 +48,26 @@ def get_candidate_configs(sm):
candidate_configs = list()
hasbias = ("false", "true")
KernelSchedule = (
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", )
EpilogueSchedule = ("TmaWarpSpecializedCooperative", )
KernelSchedule = ("KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>",)
EpilogueSchedule = ("TmaWarpSpecializedCooperative",)
TileSchedule = ("PersistentScheduler", "StreamKScheduler")
for act_tag in [
("noact", "Identity"),
# ("relu", "ReLu"),
# ("gelu", "GELU"),
# ("relu", "ReLu"),
# ("gelu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule, TileSchedule)])
candidate_configs.extend(
[
(
hasbias,
act_tag,
tiles,
KernelSchedule,
EpilogueSchedule,
TileSchedule,
)
]
)
return candidate_configs
@@ -66,16 +75,13 @@ def get_shape_str(tile_shape):
"""
return tile_shape string.
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule,
tile_schedule):
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, tile_schedule):
"""
check the cutlass config valid.
"""
@@ -304,13 +310,10 @@ def SubstituteTemplate(template, values_base):
SubstituteTemplate
"""
values = copy.deepcopy(values_base)
if values.get("KernelSchedule"
) is not None and "Auto" in values["KernelSchedule"]:
if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule"
) is not None and "Auto" in values["EpilogueSchedule"]:
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template
changed = True
while changed:
@@ -329,8 +332,7 @@ def parse_args():
parse_args
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -346,15 +348,15 @@ def parse_args():
# generate source .cu
def generate_source_cu(
inputs_type: (str),
outputs_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
TileSchedule: (str),
sm: str,
inputs_type: str,
outputs_type: str,
hasbiases: str,
act_tag: str,
tiles: str,
KernelSchedule: str,
EpilogueSchedule: str,
TileSchedule: str,
sm: str,
):
"""
generate_source_cu
@@ -369,8 +371,11 @@ def generate_source_cu(
for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule:
if not check_config_valid(
tile_config, kernel_schedule,
epilogue_schedule, tile_schedule):
tile_config,
kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue
value_dict = {
"input_type": input_type,
@@ -385,30 +390,32 @@ def generate_source_cu(
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
all_code += SubstituteTemplate(GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
generate_dir: str,
inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
"""
generate_launch_gemm_cus
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
TileSchedule: (str) = single_config[5]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
TileSchedule: str = single_config[5]
code_map = {}
head_path = os.path.join(generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_path = os.path.join(generate_dir, f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
@@ -418,19 +425,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule,
tile_schedule):
if not check_config_valid(
tile_config,
kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
"sm": sm[-2:],
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
}
head_all_code += SubstituteTemplate(
LaunchGemmDeclare, value_dict)
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
@@ -444,19 +451,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule,
tile_schedule):
if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
"sm": sm[-2:],
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
}
source_all_code = SubstituteTemplate(
LaunchGemmPart0, value_dict)
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
@@ -476,16 +483,14 @@ def generate_launch_gemm_cus(
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
gemm_config_str = gemm_config_str.replace("<", "").replace(
">", "")
gemm_config_str = gemm_config_str.replace("<", "").replace(">", "")
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
@@ -495,19 +500,18 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
"""
generate_dispatch_gemm_cu
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
TileSchedule: (str) = single_config[5]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
TileSchedule: str = single_config[5]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
for input_type in inputs_type:
@@ -530,9 +534,12 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule,
tile_schedule):
if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue
value_dict = {
"TileShape": tile_shape[0],
@@ -554,18 +561,18 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule,
tile_schedule):
if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"tile_id":
str(tile_id),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
"sm": sm[-2:],
"tile_id": str(tile_id),
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
@@ -610,12 +617,17 @@ if __name__ == "__main__":
f.close()
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
outputs_type, fuse_gemm_configs, sm_dict[sm])
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag
file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu")
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,

View File

@@ -24,27 +24,28 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]
)
return base_configs
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages,
max_stages):
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
"""
get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel.
"""
@@ -299,8 +300,7 @@ def check_min_split_k(value):
"""
ivalue = int(value)
if ivalue > 1:
raise argparse.ArgumentTypeError(
"Dual gemm split_k mode is not support.")
raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.")
return ivalue
@@ -310,8 +310,7 @@ def check_max_split_k(value):
"""
ivalue = int(value)
if ivalue > 1:
raise argparse.ArgumentTypeError(
"Dual gemm split_k mode is not support..")
raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support..")
return ivalue
@@ -320,8 +319,7 @@ def parse_args():
parse_args
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -421,8 +419,7 @@ def generate_dual_gemm_source_cu(
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(
GemmSplitKDeclare, value_dict)
all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
all_code += CommonTail
return all_code
@@ -449,12 +446,12 @@ def generate_launch_dual_gemm_cus(
head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
@@ -467,12 +464,12 @@ def generate_launch_dual_gemm_cus(
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
@@ -498,16 +495,14 @@ def generate_launch_dual_gemm_cus(
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
# split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
# source_all_code += split_k_code
# source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -566,12 +561,12 @@ def generate_dispatch_dual_gemm_cu(
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
@@ -580,10 +575,12 @@ def generate_dispatch_dual_gemm_cu(
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
value_dict.update({
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
})
value_dict.update(
{
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
}
)
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
@@ -602,8 +599,7 @@ if __name__ == "__main__":
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_dual_gemm_candidate_configs(
sm, min_split_k, max_split_k, min_stages, max_stages)
fuse_gemm_configs = get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"

View File

@@ -19,8 +19,7 @@ import re
def get_candidate_tiles():
"""
"""
""" """
cta_shape = [
("<_64, _16, _128>"),
("<_64, _32, _128>"),
@@ -45,8 +44,7 @@ def get_candidate_tiles():
def get_dual_gemm_candidate_configs(sm):
"""
"""
""" """
tiles = get_candidate_tiles()
candidate_configs = list()
@@ -64,35 +62,27 @@ def get_dual_gemm_candidate_configs(sm):
("swiglu", "SiLu"),
("geglu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule)])
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
""" """
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
"""
"""
""" """
blocks, clusters = get_shape_str(tile_shape)
if int(
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False
if tile_shape[
0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
return False
return True
@@ -302,8 +292,7 @@ bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
def SubstituteTemplate(template, values):
"""
"""
""" """
text = template
changed = True
while changed:
@@ -318,10 +307,8 @@ def SubstituteTemplate(template, values):
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
""" """
parser = argparse.ArgumentParser(description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser.add_argument(
"--cuda_arch",
type=str,
@@ -336,17 +323,16 @@ def parse_args():
# generate source .cu
def generate_dual_gemm_source_cu(
inputs_type: (str),
biases_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
sm: str,
inputs_type: str,
biases_type: str,
hasbiases: str,
act_tag: str,
tiles: str,
KernelSchedule: str,
EpilogueSchedule: str,
sm: str,
):
"""
"""
""" """
all_code = CommonHead
for input_type in inputs_type:
for bias_type in biases_type:
@@ -354,9 +340,7 @@ def generate_dual_gemm_source_cu(
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config,
kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"input_type": input_type,
@@ -370,28 +354,29 @@ def generate_dual_gemm_source_cu(
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
all_code += SubstituteTemplate(GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_dual_gemm_cus(
generate_dir: (str), inputs_type: (str), biases_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
generate_dir: str,
inputs_type: str,
biases_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
code_map = {}
head_path = os.path.join(generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
@@ -401,16 +386,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare,
value_dict)
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
@@ -422,16 +405,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0,
value_dict)
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for bias_type in biases_type:
@@ -450,14 +431,13 @@ def generate_launch_dual_gemm_cus(
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
@@ -467,16 +447,14 @@ def generate_launch_dual_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
def generate_dispatch_dual_gemm_cu(inputs_type: str, biases_type: str, fuse_gemm_configs: tuple, sm: str):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
@@ -500,8 +478,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
@@ -520,8 +497,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
@@ -570,12 +546,15 @@ if __name__ == "__main__":
f.close()
# Compile parallelization
generate_launch_dual_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
biases_type, fuse_gemm_configs, sm_dict[sm])
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
biases_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_dual_gemm_cu(
inputs_type,

View File

@@ -31,25 +31,26 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
base_configs.extend(
[
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]
)
return base_configs
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages,
max_stages):
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
"""
获取候选的gemm算子配置列表。
@@ -353,8 +354,7 @@ def parse_args():
代码参数解析
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -448,8 +448,7 @@ def generate_source_cu(
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(GemmSplitKDeclare,
value_dict)
all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
all_code += CommonTail
return all_code
@@ -473,9 +472,7 @@ def generate_launch_gemm_cus(
head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -489,9 +486,7 @@ def generate_launch_gemm_cus(
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -517,17 +512,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
split_k_code += SubstituteTemplate(
LaunchGemmPart3, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
source_all_code += split_k_code
source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -581,9 +573,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -593,10 +583,12 @@ def generate_dispatch_gemm_cu(
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
value_dict.update({
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
})
value_dict.update(
{
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
}
)
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
@@ -614,9 +606,7 @@ if __name__ == "__main__":
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_split_k,
max_split_k, min_stages,
max_stages)
fuse_gemm_configs = get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
all_code = generate_source_cu(
@@ -654,9 +644,7 @@ if __name__ == "__main__":
# hard code for act_tag
file_name = (
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
)
file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,

View File

@@ -20,44 +20,44 @@ import re
def get_candidate_tiles():
"""
"""
""" """
base_configs = [
("<_64, _64, _128>", "<_1, _8, _1>"),
("<_64, _128, _128>", "<_2, _1, _1>"),
("<_128, _128, _128>", "<_2, _1, _1>"),
]
base_configs.extend([
("<_64, _64, _128>", "<_1, _1, _1>"),
("<_64, _64, _128>", "<_1, _2, _1>"),
("<_64, _64, _128>", "<_2, _1, _1>"),
("<_64, _64, _64>", "<_1, _1, _1>"),
("<_64, _64, _64>", "<_1, _2, _1>"),
("<_64, _64, _64>", "<_2, _1, _1>"),
("<_64, _128, _128>", "<_1, _2, _1>"),
("<_64, _128, _128>", "<_1, _1, _1>"),
("<_128, _128, _64>", "<_2, _1, _1>"),
("<_256, _128, _128>", "<_1, _2, _1>"),
("<_256, _128, _128>", "<_1, _1, _1>"),
# The following configurations are rarely selected in Qwen2-7B-model.
# ("<_256, _128, _128>", "<_4, _1, _1>"),
# ("<_256, _128, _128>", "<_1, _4, _1>"),
# ("<_256, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _256>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_4, _1, _1>"),
# ("<_128, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _128>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_1, _1, _1>"),
# ("<_128, _128, _128>", "<_1, _4, _1>"),
# ("<_128, _128, _64>", "<_2, _2, _1>"),
])
base_configs.extend(
[
("<_64, _64, _128>", "<_1, _1, _1>"),
("<_64, _64, _128>", "<_1, _2, _1>"),
("<_64, _64, _128>", "<_2, _1, _1>"),
("<_64, _64, _64>", "<_1, _1, _1>"),
("<_64, _64, _64>", "<_1, _2, _1>"),
("<_64, _64, _64>", "<_2, _1, _1>"),
("<_64, _128, _128>", "<_1, _2, _1>"),
("<_64, _128, _128>", "<_1, _1, _1>"),
("<_128, _128, _64>", "<_2, _1, _1>"),
("<_256, _128, _128>", "<_1, _2, _1>"),
("<_256, _128, _128>", "<_1, _1, _1>"),
# The following configurations are rarely selected in Qwen2-7B-model.
# ("<_256, _128, _128>", "<_4, _1, _1>"),
# ("<_256, _128, _128>", "<_1, _4, _1>"),
# ("<_256, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _256>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_4, _1, _1>"),
# ("<_128, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _128>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_1, _1, _1>"),
# ("<_128, _128, _128>", "<_1, _4, _1>"),
# ("<_128, _128, _64>", "<_2, _2, _1>"),
]
)
return base_configs
def get_candidate_configs(sm):
"""
"""
""" """
tiles = get_candidate_tiles()
candidate_configs = list()
@@ -73,36 +73,31 @@ def get_candidate_configs(sm):
("relu", "ReLu"),
("gelu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule)])
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
""" """
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
"""
"""
""" """
blocks, clusters = get_shape_str(tile_shape)
if int(
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False
if (tile_shape[0] == "<_256, _128, _128>"
and "Cooperative" not in kernel_schedule
and "Cooperative" not in epilogue_schedule):
if (
tile_shape[0] == "<_256, _128, _128>"
and "Cooperative" not in kernel_schedule
and "Cooperative" not in epilogue_schedule
):
return False
return True
@@ -321,16 +316,12 @@ bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
def SubstituteTemplate(template, values_base):
"""
"""
""" """
values = copy.deepcopy(values_base)
if values.get("KernelSchedule"
) is not None and "Auto" in values["KernelSchedule"]:
if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule"
) is not None and "Auto" in values["EpilogueSchedule"]:
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template
changed = True
while changed:
@@ -345,10 +336,8 @@ def SubstituteTemplate(template, values_base):
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
""" """
parser = argparse.ArgumentParser(description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser.add_argument(
"--cuda_arch",
type=str,
@@ -363,17 +352,16 @@ def parse_args():
# generate source .cu
def generate_source_cu(
inputs_type: (str),
outputs_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
sm: str,
inputs_type: str,
outputs_type: str,
hasbiases: str,
act_tag: str,
tiles: str,
KernelSchedule: str,
EpilogueSchedule: str,
sm: str,
):
"""
"""
""" """
all_code = CommonHead
for input_type in inputs_type:
@@ -382,9 +370,7 @@ def generate_source_cu(
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config,
kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"input_type": input_type,
@@ -398,25 +384,27 @@ def generate_source_cu(
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
all_code += SubstituteTemplate(GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
generate_dir: str,
inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
code_map = {}
head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
@@ -426,16 +414,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare,
value_dict)
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
@@ -447,16 +433,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0,
value_dict)
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
@@ -475,14 +459,14 @@ def generate_launch_gemm_cus(
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu")
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -491,17 +475,15 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
@@ -524,8 +506,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
@@ -544,8 +525,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
@@ -576,7 +556,8 @@ if __name__ == "__main__":
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu")
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
)
all_code = generate_source_cu(
inputs_type,
outputs_type,
@@ -594,8 +575,12 @@ if __name__ == "__main__":
f.close()
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
outputs_type, fuse_gemm_configs, sm_dict[sm])
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"

View File

@@ -30,22 +30,24 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]
)
return base_configs
@@ -278,8 +280,7 @@ def parse_args():
代码参数解析
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -370,13 +371,10 @@ def generate_launch_gemm_cus(
- dict (code_map) - 包含每个Gemm配置对应的源代码的字典格式为{"gemm_config": source_code}。
"""
code_map = {}
head_path = os.path.join(generate_dir,
"launch_visitor_gemm_fused_kernel.h")
head_path = os.path.join(generate_dir, "launch_visitor_gemm_fused_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -390,9 +388,7 @@ def generate_launch_gemm_cus(
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -415,14 +411,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu")
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -485,9 +481,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -512,10 +506,11 @@ if __name__ == "__main__":
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_stages,
max_stages)
fuse_gemm_configs = get_candidate_configs(sm, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
)
all_code = generate_source_cu(
inputs_type,
outputs_type,
@@ -544,9 +539,7 @@ if __name__ == "__main__":
sm_dict[sm],
)
file_name = (
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
)
file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,

View File

@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
vsl.kv_lod_vp = {
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1, nullptr};
baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
nullptr,
0,

View File

@@ -30,8 +30,7 @@ current_file = Path(__file__).resolve()
base_dir = current_file.parent
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
XDNN_LIB_DIR):
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
"""
build xpu plugin
"""
@@ -49,7 +48,10 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# 删除指定目录
dirs_to_remove = [
"dist", "fastdeploy_ops.egg-info", "build", "plugin/build"
"dist",
"fastdeploy_ops.egg-info",
"build",
"plugin/build",
]
for dir_name in dirs_to_remove:
if os.path.exists(dir_name):
@@ -58,8 +60,7 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# 在 plugin 目录中执行构建脚本
plugin_dir = "plugin"
build_script = os.path.join(current_working_directory, plugin_dir,
"build.sh")
build_script = os.path.join(current_working_directory, plugin_dir, "build.sh")
print("build_script: ", build_script)
@@ -74,14 +75,16 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# 执行构建脚本
try:
print("Running build script...")
subprocess.run([build_script],
check=True,
cwd=os.path.join(current_working_directory, plugin_dir))
subprocess.run(
[build_script],
check=True,
cwd=os.path.join(current_working_directory, plugin_dir),
)
print("Build completed successfully.")
except subprocess.CalledProcessError as e:
print(f"Build failed with error: {e}")
except Exception as e:
print(f"Unexpected error: {str(e)}")
print(f"Unexpected error: {e!s}")
def xpu_setup_ops():
@@ -124,17 +127,14 @@ def xpu_setup_ops():
XVLLM_PATH = os.getenv("XVLLM_PATH")
assert XVLLM_PATH is not None, "XVLLM_PATH is not set."
XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include")
XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so",
"libapiinfer.so")
XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so", "libapiinfer.so")
XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so")
XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include")
XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so",
"libxft_blocks.so")
XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so", "libxft_blocks.so")
XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so")
# build plugin
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH,
XDNN_LIB_DIR)
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
ops = [
# custom ops
@@ -152,7 +152,6 @@ def xpu_setup_ops():
"./ops/block_attn.cc",
"./ops/moe_layer.cc",
"./ops/weight_quantize_xpu.cc",
# device manage ops
"./ops/device/get_context_gm_max_mem_demand.cc",
"./ops/device/get_free_global_memory.cc",

View File

@@ -29,7 +29,7 @@ for i in range(bs):
ids_len = seq_lens[i, 0]
input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64")
x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
(x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k,) = get_padding_offset(
paddle.to_tensor(input_ids),
paddle.to_tensor(cum_offset),
paddle.to_tensor(token_num),
@@ -46,19 +46,14 @@ print("padding_offset:\n", padding_offset)
print("cu_seqlens_q:\n", cu_seqlens_q)
print("cu_seqlens_k:\n", cu_seqlens_k)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6],
"int64")
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
ref_cum_offsets_out = np.array([0, 6, 13], "int32")
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13],
"int32")
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
assert sum(ref_x_remove_padding -
x_remove_padding) == 0, 'Check x_remove_padding failed.'
assert sum(ref_cum_offsets_out -
cum_offsets_out) == 0, 'Check cum_offsets_out failed.'
assert sum(ref_padding_offset -
padding_offset) == 0, 'Check padding_offset failed.'
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, 'Check cu_seqlens_q failed.'
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, 'Check cu_seqlens_k failed.'
assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."

View File

@@ -21,10 +21,15 @@ paddle.seed(2023)
pre_ids = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64")
logits = paddle.to_tensor([[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1]],
"float32")
"int64",
)
logits = paddle.to_tensor(
[
[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1],
],
"float32",
)
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
frequency_scores = paddle.to_tensor([0.1, 0.1], "float32")
presence_scores = paddle.to_tensor([0.0, 0.0], "float32")
@@ -88,78 +93,536 @@ ref_logits = np.array(
)
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, 'Check failed.'
assert diff_logits < 1e-6, "Check failed."
pre_ids = paddle.to_tensor(
[[
2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8,
7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4,
9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6,
1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8,
7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4,
7, 5, 5, 7, 6, 9, 3, 9
],
[
7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1,
9, 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1,
5, 1, 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8,
8, 4, 8, 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6,
3, 5, 1, 4, 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6,
4, 7, 7, 3, 8, 5, 1, 9, 1, 2, 8, 6, 8
]])
[
[
2,
3,
3,
5,
8,
9,
3,
9,
1,
8,
9,
2,
3,
8,
8,
9,
9,
1,
4,
2,
6,
2,
6,
8,
7,
2,
2,
3,
8,
1,
5,
7,
9,
2,
2,
9,
1,
4,
9,
8,
5,
8,
5,
7,
3,
6,
4,
4,
9,
9,
8,
5,
5,
2,
2,
9,
4,
8,
1,
9,
6,
9,
2,
2,
7,
2,
2,
9,
4,
6,
4,
6,
1,
4,
1,
9,
1,
8,
8,
5,
7,
9,
4,
2,
5,
1,
1,
4,
1,
5,
5,
4,
4,
2,
1,
8,
7,
1,
2,
9,
6,
7,
9,
6,
7,
7,
4,
9,
9,
7,
5,
1,
8,
9,
8,
8,
5,
4,
6,
4,
7,
5,
5,
7,
6,
9,
3,
9,
],
[
7,
8,
1,
3,
1,
7,
6,
3,
5,
3,
8,
3,
1,
9,
7,
1,
1,
9,
5,
4,
9,
6,
1,
9,
3,
8,
3,
9,
9,
6,
4,
2,
8,
5,
3,
1,
6,
9,
1,
3,
9,
8,
1,
7,
5,
1,
5,
1,
8,
7,
4,
5,
9,
8,
7,
4,
7,
3,
6,
4,
6,
6,
5,
5,
2,
9,
9,
5,
8,
8,
4,
8,
2,
8,
1,
3,
9,
1,
8,
5,
8,
3,
8,
8,
2,
7,
3,
7,
5,
7,
2,
6,
3,
5,
1,
4,
6,
1,
9,
8,
2,
2,
3,
6,
7,
6,
2,
6,
5,
1,
5,
6,
2,
1,
6,
4,
7,
7,
3,
8,
5,
1,
9,
1,
2,
8,
6,
8,
],
]
)
logits = paddle.to_tensor(
[[
0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748,
0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807,
0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218,
0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679,
0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888,
0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911,
0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448,
0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415,
0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104,
0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064,
0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672,
0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840,
0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613,
0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254,
0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189,
0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407,
0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923,
0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550,
0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222,
0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845,
0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070,
0.98481447, 0.77367836
],
[
0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417,
0.01848883, 0.78326035, 0.99228370, 0.81447607, 0.02627683,
0.51033205, 0.98703283, 0.15247856, 0.77640921, 0.60799915,
0.87518770, 0.76818430, 0.86542630, 0.31795895, 0.04829503,
0.85567141, 0.30271924, 0.67515039, 0.59728831, 0.78710967,
0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547,
0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396,
0.95318258, 0.93987304, 0.61142951, 0.26737869, 0.52285451,
0.03479086, 0.61631846, 0.66777998, 0.15736090, 0.00447258,
0.37035006, 0.15281211, 0.95372260, 0.25963321, 0.61036694,
0.15020694, 0.19171195, 0.55252832, 0.00391038, 0.31052542,
0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293,
0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045,
0.13202040, 0.07133843, 0.75313550, 0.17111187, 0.80716974,
0.00172165, 0.83906764, 0.73240769, 0.85843354, 0.11042888,
0.07912333, 0.33689004, 0.22334915, 0.59059596, 0.52789515,
0.29831955, 0.39515004, 0.55602801, 0.83818001, 0.05865780,
0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544,
0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529,
0.29571742, 0.76361233, 0.72476572, 0.18568406, 0.85430276,
0.02057583, 0.76195669, 0.65507215, 0.69129735, 0.25084621,
0.75223947, 0.06064088, 0.20287007, 0.35887691, 0.75043523,
0.47575447, 0.40021798, 0.44464844, 0.67975360, 0.40443239,
0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721,
0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736,
0.87317258, 0.88229448, 0.79217333
]])
[
[
0.16274983,
0.61470598,
0.94366980,
0.82005417,
0.50752640,
0.38316748,
0.92648441,
0.24050158,
0.05461595,
0.42218581,
0.36270225,
0.15464807,
0.13614719,
0.67509544,
0.40315166,
0.10671722,
0.24832056,
0.76091218,
0.11598995,
0.10962527,
0.04688513,
0.81536716,
0.72259802,
0.60476679,
0.16701800,
0.84160781,
0.79649884,
0.78021604,
0.75329530,
0.98587888,
0.13421868,
0.16027625,
0.15269397,
0.06228730,
0.73856270,
0.34721911,
0.73683006,
0.78178608,
0.32068327,
0.79906309,
0.44214272,
0.63330448,
0.08016958,
0.63367140,
0.19788943,
0.55346787,
0.11142531,
0.90518415,
0.21236691,
0.81587470,
0.83752930,
0.70979482,
0.35684183,
0.28715104,
0.87162822,
0.17679396,
0.98725849,
0.76129991,
0.04090235,
0.37181064,
0.63317049,
0.24689502,
0.21126501,
0.57617670,
0.74346697,
0.40613672,
0.56907010,
0.68556929,
0.29032683,
0.17866278,
0.35165095,
0.97015840,
0.70785582,
0.54259878,
0.14712237,
0.90483177,
0.02094105,
0.36411613,
0.02495066,
0.88874054,
0.88895452,
0.86216462,
0.58062190,
0.95583254,
0.20553111,
0.29870346,
0.69652933,
0.36861244,
0.85316223,
0.50240189,
0.17566244,
0.61080140,
0.88203174,
0.98675215,
0.24344546,
0.17213407,
0.78160852,
0.25165486,
0.48188508,
0.82812423,
0.10199814,
0.90475923,
0.66907483,
0.71910626,
0.40660757,
0.59460294,
0.70212913,
0.90841550,
0.00329034,
0.11290466,
0.89654654,
0.69114941,
0.29473618,
0.62027222,
0.37333879,
0.98911142,
0.46510187,
0.65914583,
0.73022646,
0.12790845,
0.12817244,
0.43015456,
0.75011456,
0.43562204,
0.48086026,
0.75587070,
0.98481447,
0.77367836,
],
[
0.12336024,
0.74152875,
0.09191196,
0.99301219,
0.44764417,
0.01848883,
0.78326035,
0.99228370,
0.81447607,
0.02627683,
0.51033205,
0.98703283,
0.15247856,
0.77640921,
0.60799915,
0.87518770,
0.76818430,
0.86542630,
0.31795895,
0.04829503,
0.85567141,
0.30271924,
0.67515039,
0.59728831,
0.78710967,
0.75111693,
0.56837374,
0.49085775,
0.91510201,
0.59545547,
0.99482232,
0.59036905,
0.58267909,
0.28770933,
0.53237396,
0.95318258,
0.93987304,
0.61142951,
0.26737869,
0.52285451,
0.03479086,
0.61631846,
0.66777998,
0.15736090,
0.00447258,
0.37035006,
0.15281211,
0.95372260,
0.25963321,
0.61036694,
0.15020694,
0.19171195,
0.55252832,
0.00391038,
0.31052542,
0.96495175,
0.42586124,
0.05630261,
0.99728668,
0.01856293,
0.83201504,
0.10701843,
0.56434178,
0.38009524,
0.51095045,
0.13202040,
0.07133843,
0.75313550,
0.17111187,
0.80716974,
0.00172165,
0.83906764,
0.73240769,
0.85843354,
0.11042888,
0.07912333,
0.33689004,
0.22334915,
0.59059596,
0.52789515,
0.29831955,
0.39515004,
0.55602801,
0.83818001,
0.05865780,
0.25654668,
0.76624149,
0.35190639,
0.04158346,
0.59157544,
0.30779791,
0.94609004,
0.10759670,
0.65575141,
0.37828529,
0.29571742,
0.76361233,
0.72476572,
0.18568406,
0.85430276,
0.02057583,
0.76195669,
0.65507215,
0.69129735,
0.25084621,
0.75223947,
0.06064088,
0.20287007,
0.35887691,
0.75043523,
0.47575447,
0.40021798,
0.44464844,
0.67975360,
0.40443239,
0.71052992,
0.21782248,
0.50568426,
0.89037591,
0.06661721,
0.28788096,
0.70773387,
0.42428264,
0.80419677,
0.42710736,
0.87317258,
0.88229448,
0.79217333,
],
]
)
# pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
# logits = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
@@ -195,60 +658,270 @@ print("min_len\n", min_len)
print("eos_token_id\n", eos_token_id)
ref_logits = np.array(
[[
-10000000000., -10000000000., 1.88733959, 1.64010835, 1.01505280,
0.76633495, 1.85296881, 0.48100317, 0.10923190, 0.84437162, 0.72540450,
0.30929613, 0.27229437, 1.35019088, 0.80630332, 0.21343444, 0.49664113,
1.52182436, 0.23197991, 0.21925054, 0.09377026, 1.63073432, 1.44519603,
1.20953357, 0.33403599, 1.68321562, 1.59299767, 1.56043208, 1.50659060,
1.97175777, 0.26843736, 0.32055250, 0.30538794, 0.12457460, 1.47712541,
0.69443822, 1.47366011, 1.56357217, 0.64136654, 1.59812617, 0.88428545,
1.26660895, 0.16033916, 1.26734281, 0.39577886, 1.10693574, 0.22285062,
1.81036830, 0.42473382, 1.63174939, 1.67505860, 1.41958964, 0.71368366,
0.57430208, 1.74325645, 0.35358793, 1.97451699, 1.52259982, 0.08180470,
0.74362129, 1.26634097, 0.49379003, 0.42253003, 1.15235341, 1.48693395,
0.81227344, 1.13814020, 1.37113857, 0.58065367, 0.35732555, 0.70330191,
1.94031680, 1.41571164, 1.08519757, 0.29424474, 1.80966353, 0.04188210,
0.72823226, 0.04990132, 1.77748108, 1.77790904, 1.72432923, 1.16124380,
1.91166508, 0.41106221, 0.59740692, 1.39305866, 0.73722488, 1.70632446,
1.00480378, 0.35132489, 1.22160280, 1.76406348, 1.97350430, 0.48689091,
0.34426814, 1.56321704, 0.50330973, 0.96377015, 1.65624845, 0.20399629,
1.80951846, 1.33814967, 1.43821251, 0.81321514, 1.18920588, 1.40425825,
1.81683099, 0.00658068, 0.22580932, 1.79309309, 1.38229883, 0.58947235,
1.24054444, 0.74667758, 1.97822285, 0.93020374, 1.31829166, 1.46045291,
0.25581691, 0.25634488, 0.86030912, 1.50022912, 0.87124407, 0.96172053,
1.51174140, 1.96962893, 1.54735672
[
[
-10000000000.0,
-10000000000.0,
1.88733959,
1.64010835,
1.01505280,
0.76633495,
1.85296881,
0.48100317,
0.10923190,
0.84437162,
0.72540450,
0.30929613,
0.27229437,
1.35019088,
0.80630332,
0.21343444,
0.49664113,
1.52182436,
0.23197991,
0.21925054,
0.09377026,
1.63073432,
1.44519603,
1.20953357,
0.33403599,
1.68321562,
1.59299767,
1.56043208,
1.50659060,
1.97175777,
0.26843736,
0.32055250,
0.30538794,
0.12457460,
1.47712541,
0.69443822,
1.47366011,
1.56357217,
0.64136654,
1.59812617,
0.88428545,
1.26660895,
0.16033916,
1.26734281,
0.39577886,
1.10693574,
0.22285062,
1.81036830,
0.42473382,
1.63174939,
1.67505860,
1.41958964,
0.71368366,
0.57430208,
1.74325645,
0.35358793,
1.97451699,
1.52259982,
0.08180470,
0.74362129,
1.26634097,
0.49379003,
0.42253003,
1.15235341,
1.48693395,
0.81227344,
1.13814020,
1.37113857,
0.58065367,
0.35732555,
0.70330191,
1.94031680,
1.41571164,
1.08519757,
0.29424474,
1.80966353,
0.04188210,
0.72823226,
0.04990132,
1.77748108,
1.77790904,
1.72432923,
1.16124380,
1.91166508,
0.41106221,
0.59740692,
1.39305866,
0.73722488,
1.70632446,
1.00480378,
0.35132489,
1.22160280,
1.76406348,
1.97350430,
0.48689091,
0.34426814,
1.56321704,
0.50330973,
0.96377015,
1.65624845,
0.20399629,
1.80951846,
1.33814967,
1.43821251,
0.81321514,
1.18920588,
1.40425825,
1.81683099,
0.00658068,
0.22580932,
1.79309309,
1.38229883,
0.58947235,
1.24054444,
0.74667758,
1.97822285,
0.93020374,
1.31829166,
1.46045291,
0.25581691,
0.25634488,
0.86030912,
1.50022912,
0.87124407,
0.96172053,
1.51174140,
1.96962893,
1.54735672,
],
[
-10000000000.0,
-10000000000.0,
-40000.0,
3.97204876,
1.79057670,
0.07395532,
3.13304138,
3.96913481,
3.25790429,
-40000.0,
2.04132819,
3.94813132,
0.60991424,
3.10563684,
2.43199658,
3.50075078,
3.07273722,
3.46170521,
1.27183580,
0.19318011,
3.42268562,
1.21087694,
2.70060158,
2.38915324,
3.14843869,
3.00446773,
2.27349496,
1.96343100,
3.66040802,
2.38182187,
3.97928929,
2.36147618,
2.33071637,
1.15083730,
2.12949586,
3.81273031,
3.75949216,
2.44571805,
1.06951475,
2.09141803,
0.13916343,
2.46527386,
2.67111993,
0.62944359,
0.01789032,
1.48140025,
0.61124843,
3.81489038,
1.03853285,
2.44146776,
0.60082775,
0.76684779,
2.21011329,
0.01564152,
1.24210167,
3.85980701,
1.70344496,
0.22521044,
3.98914671,
0.07425172,
3.32806015,
0.42807373,
2.25736713,
1.52038097,
2.04380178,
0.52808160,
0.28535372,
3.01254201,
0.68444747,
3.22867894,
0.00688660,
3.35627055,
2.92963076,
3.43373418,
0.44171551,
0.31649333,
1.34756017,
0.89339662,
2.36238384,
2.11158061,
1.19327819,
1.58060014,
2.22411203,
3.35272002,
0.23463120,
1.02618670,
3.06496596,
1.40762556,
0.16633384,
2.36630177,
1.23119164,
3.78436017,
0.43038681,
2.62300563,
1.51314116,
1.18286967,
3.05444932,
2.89906287,
0.74273622,
3.41721106,
0.08230332,
3.04782677,
2.62028861,
2.76518941,
1.00338483,
3.00895786,
0.24256352,
0.81148028,
1.43550766,
3.00174093,
1.90301788,
1.60087192,
1.77859378,
2.71901441,
1.61772954,
2.84211969,
0.87128991,
2.02273703,
3.56150365,
0.26646885,
1.15152383,
2.83093548,
1.69713056,
3.21678710,
1.70842946,
3.49269032,
3.52917790,
3.16869330,
],
],
[
-10000000000., -10000000000., -40000., 3.97204876, 1.79057670,
0.07395532, 3.13304138, 3.96913481, 3.25790429, -40000., 2.04132819,
3.94813132, 0.60991424, 3.10563684, 2.43199658, 3.50075078,
3.07273722, 3.46170521, 1.27183580, 0.19318011, 3.42268562,
1.21087694, 2.70060158, 2.38915324, 3.14843869, 3.00446773,
2.27349496, 1.96343100, 3.66040802, 2.38182187, 3.97928929,
2.36147618, 2.33071637, 1.15083730, 2.12949586, 3.81273031,
3.75949216, 2.44571805, 1.06951475, 2.09141803, 0.13916343,
2.46527386, 2.67111993, 0.62944359, 0.01789032, 1.48140025,
0.61124843, 3.81489038, 1.03853285, 2.44146776, 0.60082775,
0.76684779, 2.21011329, 0.01564152, 1.24210167, 3.85980701,
1.70344496, 0.22521044, 3.98914671, 0.07425172, 3.32806015,
0.42807373, 2.25736713, 1.52038097, 2.04380178, 0.52808160,
0.28535372, 3.01254201, 0.68444747, 3.22867894, 0.00688660,
3.35627055, 2.92963076, 3.43373418, 0.44171551, 0.31649333,
1.34756017, 0.89339662, 2.36238384, 2.11158061, 1.19327819,
1.58060014, 2.22411203, 3.35272002, 0.23463120, 1.02618670,
3.06496596, 1.40762556, 0.16633384, 2.36630177, 1.23119164,
3.78436017, 0.43038681, 2.62300563, 1.51314116, 1.18286967,
3.05444932, 2.89906287, 0.74273622, 3.41721106, 0.08230332,
3.04782677, 2.62028861, 2.76518941, 1.00338483, 3.00895786,
0.24256352, 0.81148028, 1.43550766, 3.00174093, 1.90301788,
1.60087192, 1.77859378, 2.71901441, 1.61772954, 2.84211969,
0.87128991, 2.02273703, 3.56150365, 0.26646885, 1.15152383,
2.83093548, 1.69713056, 3.21678710, 1.70842946, 3.49269032,
3.52917790, 3.16869330
]],
"float32",
)
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, 'Check failed.'
assert diff_logits < 1e-6, "Check failed."

View File

@@ -21,19 +21,30 @@ paddle.seed(2023)
pre_ids_all = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64")
input_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1]],
"int64")
"int64",
)
input_ids = paddle.to_tensor(
[
[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1],
],
"int64",
)
seq_lens_this_time = paddle.to_tensor([1, 1], "int32")
seq_lens_encoder = paddle.to_tensor([1, 1], "int32")
seq_lens_decoder = paddle.to_tensor([1, 1], "int32")
step_idx = paddle.to_tensor([1, 1], "int64")
stop_flags = paddle.to_tensor([0, 1], "bool")
print("pre_ids_all\n", pre_ids_all)
set_value_by_flags_and_idx(pre_ids_all, input_ids, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, step_idx,
stop_flags)
set_value_by_flags_and_idx(
pre_ids_all,
input_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
stop_flags,
)
print("pre_ids_all\n", pre_ids_all)
print("input_ids\n", input_ids)
print("seq_lens_this_time\n", seq_lens_this_time)
@@ -73,4 +84,4 @@ ref_pre_ids_all = np.array(
)
diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy()))
print("diff_pre_ids_all\n", diff_pre_ids_all)
assert diff_pre_ids_all == 0, 'Check failed.'
assert diff_pre_ids_all == 0, "Check failed."

View File

@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
max_block_num = block_bs * max_seq_len // block_size
free_list_len = int(max_block_num * (1 - block_ratio))
free_list_len = np.full([1], free_list_len, "int32")
free_list = np.arange(max_block_num - 1,
max_block_num - free_list_len - 1,
-1,
dtype="int32")
free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")
encoder_block_lens = np.zeros([max_bs], "int32")
used_list_len = np.zeros([max_bs], "int32")
@@ -53,19 +50,15 @@ encoder_block_id = 0
for i in range(bs):
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
encoder_block_lens[i] = enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size -
1) // block_size - enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
used_list_len[i] = dec_block_num
block_tables[i, :enc_block_num] = np.arange(
encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id += enc_block_num
if dec_block_num > 0:
block_tables[
i, enc_block_num:enc_block_num +
dec_block_num] = free_list[free_list_len[0] - 1 -
dec_block_num:free_list_len[0] - 1]
free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
1] = -1
block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
]
free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list_len[0] -= dec_block_num
assert free_list_len[0] >= 0
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
# print("step_idx: ", step_idx)
# print("next_tokens: ", next_tokens)
step_paddle(stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder, seq_lens_decoder, block_tables,
encoder_block_lens, is_block_step, step_block_list, step_lens,
recover_block_list, recover_lens, need_block_list, need_block_len,
used_list_len, free_list, free_list_len, input_ids, pre_ids,
step_idx, next_tokens, first_token_ids, block_size,
encoder_decoder_block_num)
step_paddle(
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_lens,
recover_block_list,
recover_lens,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
next_tokens,
first_token_ids,
block_size,
encoder_decoder_block_num,
)
print("-" * 50 + "after step op" + "-" * 50)
print("stop_flags: ", stop_flags)

View File

@@ -30,8 +30,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
False)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, False)
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
@@ -40,44 +39,220 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array(
[
0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, 18, 19, 20,
0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, -1, 37, 38, 39,
-1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, 0, -1, 0, 57, -1, 59,
60, 0, 0, 63
0,
0,
2,
3,
-1,
0,
0,
0,
0,
9,
10,
0,
12,
0,
-1,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
-1,
29,
30,
31,
0,
0,
0,
-1,
-1,
37,
38,
39,
-1,
-1,
0,
0,
0,
0,
46,
-1,
0,
49,
50,
0,
52,
53,
0,
-1,
0,
57,
-1,
59,
60,
0,
0,
63,
],
"int64",
)
ref_next_tokens = np.array(
[
0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, 18, 19, 20,
0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, 0, 37, 38, 39, 0,
0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, 0, 0, 0, 57, 0, 59, 60, 0,
0, 63
0,
0,
2,
3,
0,
0,
0,
0,
0,
9,
10,
0,
12,
0,
0,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
0,
29,
30,
31,
0,
0,
0,
0,
0,
37,
38,
39,
0,
0,
0,
0,
0,
0,
46,
0,
0,
49,
50,
0,
52,
53,
0,
0,
0,
57,
0,
59,
60,
0,
0,
63,
],
"int64",
)
ref_stop_flags = np.array(
[
True, True, True, True, True, True, True, True, True, False, False,
True, False, True, True, False, False, True, False, False, False, True,
False, False, True, False, False, False, True, False, False, False,
True, True, True, True, True, False, False, False, True, True, True,
True, True, True, False, True, True, False, False, True, False, False,
True, True, True, False, True, False, False, True, True, False
True,
True,
True,
True,
True,
True,
True,
True,
True,
False,
False,
True,
False,
True,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
False,
True,
True,
True,
True,
True,
False,
False,
False,
True,
True,
True,
True,
True,
True,
False,
True,
True,
False,
False,
True,
False,
False,
True,
True,
True,
False,
True,
False,
False,
True,
True,
False,
],
"bool",
)
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, 'Check failed.'
assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, 'Check failed.'
diff_stop_flags = np.sum(
np.abs(
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, 'Check failed.'
assert diff_stop_flags == 0, "Check failed."
# test beam_search=True
topk_ids = paddle.arange(0, bs, dtype="int64")
@@ -88,8 +263,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
True)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, True)
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
@@ -98,42 +272,217 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array(
[
0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, 18, 19,
20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, 36, 37, 0,
-1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58,
-1, 60, 61, -1, 63
0,
1,
2,
3,
4,
0,
6,
7,
-1,
9,
10,
0,
-1,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
-1,
-1,
28,
29,
0,
0,
-1,
33,
34,
35,
36,
37,
0,
-1,
0,
41,
-1,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
-1,
60,
61,
-1,
63,
],
"int64",
)
ref_next_tokens = np.array(
[
0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, 18, 19, 20,
0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, 36, 37, 0, 0, 0,
41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 0, 60, 61,
0, 63
0,
1,
2,
3,
4,
0,
6,
7,
0,
9,
10,
0,
0,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
0,
0,
28,
29,
0,
0,
0,
33,
34,
35,
36,
37,
0,
0,
0,
41,
0,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
0,
60,
61,
0,
63,
],
"int64",
)
ref_stop_flags = np.array(
[
False, False, False, False, False, True, False, False, True, False,
False, True, True, False, False, False, True, False, False, False,
False, True, False, False, False, False, True, True, False, False,
True, True, True, False, False, False, False, False, True, True, True,
False, True, True, False, False, False, True, True, False, True, True,
True, False, True, True, True, True, False, True, False, False, True,
False
False,
False,
False,
False,
False,
True,
False,
False,
True,
False,
False,
True,
True,
False,
False,
False,
True,
False,
False,
False,
False,
True,
False,
False,
False,
False,
True,
True,
False,
False,
True,
True,
True,
False,
False,
False,
False,
False,
True,
True,
True,
False,
True,
True,
False,
False,
False,
True,
True,
False,
True,
True,
True,
False,
True,
True,
True,
True,
False,
True,
False,
False,
True,
False,
],
"bool",
)
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, 'Check failed.'
assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, 'Check failed.'
diff_stop_flags = np.sum(
np.abs(
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, 'Check failed.'
assert diff_stop_flags == 0, "Check failed."

View File

@@ -60,9 +60,17 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens)
print("is_block_step:\n", is_block_step)
update_inputs(stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder,
seq_lens_decoder, input_ids, stop_nums, next_tokens,
is_block_step)
update_inputs(
stop_flags,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
next_tokens,
is_block_step,
)
print("-" * 50)
print("stop_flags:\n", stop_flags)
@@ -75,32 +83,269 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens)
ref_not_need_stop_out = np.array([True])
ref_seq_lens_this_time_out = np.array([
0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1,
0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1
], "int32")
ref_seq_lens_encoder_out = np.array([
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], "int32")
ref_seq_lens_decoder_out = np.array([
0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0,
24, 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0,
46, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], "int32")
input_ids_np[:, 0] = np.array([
6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, 5,
9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, 6, 7,
4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8
], "int64")
ref_seq_lens_this_time_out = np.array(
[
0,
0,
1,
0,
0,
1,
0,
1,
1,
1,
0,
1,
1,
0,
0,
0,
0,
0,
0,
0,
1,
1,
0,
1,
1,
0,
1,
1,
0,
0,
0,
1,
1,
0,
1,
0,
0,
1,
0,
1,
0,
0,
1,
0,
0,
1,
1,
1,
],
"int32",
)
ref_seq_lens_encoder_out = np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
ref_seq_lens_decoder_out = np.array(
[
0,
0,
2,
0,
0,
6,
0,
8,
8,
10,
0,
12,
12,
0,
0,
0,
0,
0,
0,
0,
20,
22,
0,
24,
24,
0,
26,
28,
0,
0,
0,
32,
32,
0,
34,
0,
0,
38,
0,
40,
0,
0,
42,
0,
0,
46,
46,
48,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
input_ids_np[:, 0] = np.array(
[
6,
5,
9,
8,
6,
2,
8,
1,
3,
1,
3,
6,
9,
8,
1,
9,
1,
8,
8,
6,
7,
6,
5,
3,
5,
9,
3,
6,
3,
9,
8,
8,
8,
8,
4,
8,
7,
4,
2,
3,
5,
8,
4,
2,
5,
6,
8,
9,
6,
7,
4,
2,
4,
6,
2,
3,
4,
9,
7,
2,
1,
8,
7,
8,
],
"int64",
)
assert not_need_stop.numpy(
) == ref_not_need_stop_out, 'Check not_need_stop failed.'
assert np.all(seq_lens_this_time.numpy() ==
ref_seq_lens_this_time_out), 'Check seq_lens_this_time failed.'
assert np.all(seq_lens_encoder.numpy() ==
ref_seq_lens_encoder_out), 'Check seq_lens_encoder failed.'
assert np.all(seq_lens_decoder.numpy() ==
ref_seq_lens_decoder_out), 'Check seq_lens_decoder failed.'
assert np.all(input_ids.numpy() == input_ids_np), 'Check input_ids failed.'
assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."

View File

@@ -29,16 +29,15 @@ def np_quant_weight_int4(weight_np):
weight = np.transpose(weight_np, [1, 0]) # n,k
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1) # k => k,1
quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7) # n,k
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (
quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
return quanted_weight, weight_scales.astype(np.float32)
def np_quant_weight(weight_np, algo='weight_only_int8'):
def np_quant_weight(weight_np, algo="weight_only_int8"):
assert weight_np.dtype == np.float32
if algo == 'weight_only_int4':
if algo == "weight_only_int4":
return np_quant_weight_int4(weight_np)
weight = np.transpose(weight_np, [1, 0])
@@ -56,7 +55,7 @@ def int8_to_bin_np(value):
def int8_to_bin(value):
if not -128 <= value <= 127:
raise ValueError("int8 值必须在 -128 到 127 之间")
return format(value & 0xFF, '08b') # '08b' 表示 8 位二进制,高位补零
return format(value & 0xFF, "08b") # '08b' 表示 8 位二进制,高位补零
# 1) preparation
@@ -70,7 +69,7 @@ w_np = (np.random.random((k, n)).astype(np.float32) - 0.5) * 10
qw_np, wscale_np = np_quant_weight(w_np, algo)
# 3) xpu calculation
dtype = 'float32'
dtype = "float32"
x_pd = paddle.to_tensor(w_np, dtype=dtype)
qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1)
qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
@@ -83,12 +82,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
# comparation
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
print(
f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}"
)
print(
f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}"
)
sum_diff = np.sum(
np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
print(f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}")
sum_diff = np.sum(np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"sum_diff: {sum_diff}")

View File

@@ -37,4 +37,4 @@ python benchmark_serving.py \
--num-prompts 1 \
--max-concurrency 1 \
--save-result
```
```

View File

@@ -15,7 +15,7 @@ We provide two transmission methods for KV Cache, targeting intra-machine and in
Uses cudaMemcpyPeer for KV Cache transmission between two GPUs within a single machine, offering low latency and high throughput.
### Inter-machine Transmission
For transmission between multiple machines, uses high-speed RDMA network for KV Cache transmission. We provide the `rdma_comm` high-speed transmission network library for cross-machine KV Cache transmission.
For transmission between multiple machines, uses high-speed RDMA network for KV Cache transmission. We provide the `rdma_comm` high-speed transmission network library for cross-machine KV Cache transmission.
## PD Disaggregated Scheduling
![Splitwise Scheduler](./images/disaggregated.png)
@@ -60,7 +60,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--cache-queue-port 8187 \
--tensor-parallel-size 4 \
--quantization wint4 \
--innode-prefill-ports 8182 \
--innode-prefill-ports 8182 \
--splitwise-role "decode"
```
@@ -72,7 +72,8 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
### Multi-machine Disaggregated Deployment
#### Prerequisite: Redis
- Installation via `conda`
* Installation via `conda`
```bash
# Install
conda install redis
@@ -80,7 +81,8 @@ conda install redis
nohup redis-server > redis.log 2>&1 &
```
- Installation via `apt`
* Installation via `apt`
```bash
# Install
sudo apt install redis-server -y
@@ -88,7 +90,8 @@ sudo apt install redis-server -y
sudo systemctl start redis-server
```
- Installation via `yum`
* Installation via `yum`
```bash
# Install
sudo yum install redis -y

View File

@@ -38,6 +38,7 @@ conda install redis
# Launch
nohup redis-server > redis.log 2>&1 &
```
### apt installation (Debian/Ubuntu)
```bash
@@ -57,6 +58,7 @@ sudo systemctl start redis
```
## Launching FastDeploy
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--port 8801 \
@@ -72,6 +74,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-min-load_score 3 \
--scheduler-load-shards-num 1
```
[Scheduler Launching Parameter](../online_serving/scheduler.md)
### Deployment notes:

View File

@@ -36,4 +36,4 @@ python -m fastdeploy.entrypoints.openai.api_server \
Set `enable_prefix_caching=True` when launching FastDeploy. Enable CPU caching via `swap_space` based on available machine memory.
A test example is provided: `demo/offline_prefix_caching_demo.py`
A test example is provided: `demo/offline_prefix_caching_demo.py`

View File

@@ -18,8 +18,9 @@ Interfaces that support toggling the reasoning mode:
For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request.
### Quick Start
When launching the model service, specify the parser name using the `--reasoning-parser` argument.
When launching the model service, specify the parser name using the `--reasoning-parser` argument.
This parser will process the model's output and extract the `reasoning_content` field.
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \
@@ -29,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \
--reasoning-parser ernie-45-vl
```
Next, make a request to the model that should return the reasoning content in the response.
```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \
@@ -43,10 +46,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
"metadata": {"enable_thinking": true}
}'
```
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
### Streaming chat completions
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks`
```python
from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server.
@@ -69,4 +74,4 @@ for chunk in chat_response:
if chunk.choices[0].delta is not None:
print(chunk.choices[0].delta, end='')
print("\n")
```
```

View File

@@ -10,22 +10,22 @@ This project implements an efficient **Speculative Decoding** inference framewor
- **Ngram**
- **MTP (Multi-Token Prediction)**
- ✅ Supported: TP Sharding
- ✅ Supported: Shared Prefix
- ✅ Supported: TP Sharding + PD Separation
- **MTP (Multi-Token Prediction)**
- ✅ Supported: TP Sharding
- ✅ Supported: Shared Prefix
- ✅ Supported: TP Sharding + PD Separation
- ⏳ Coming Soon: EP + DP + PD Separation
- ⏳ Coming Soon: Support Chunk-prefill
- ⏳ Coming Soon: Multi-layer MTP Layer
- ⏳ Coming Soon: Multi-layer MTP Layer
---
### Coming Soon
- Draft Model
- Eagle
- Hydra
- Medusa
- Draft Model
- Eagle
- Hydra
- Medusa
- ...
---
@@ -54,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor
## 🚀 Using Multi-Token Prediction (MTP)
For detailed theory, refer to:
For detailed theory, refer to:
📄 [DeepSeek-V3 Paper](https://arxiv.org/pdf/2412.19437)
### TP Sharding Mode
@@ -147,4 +147,4 @@ python -m fastdeploy.entrypoints.openai.api_server \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
```
```

View File

@@ -132,4 +132,3 @@ Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```

View File

@@ -6,4 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
- [Kunlun XPU Installation](kunlunxin_xpu.md)
- [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md)
- [Hygon DCU Installation](hygon_dcu.md)
- [Hygon DCU Installation](hygon_dcu.md)

View File

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
```
## 2. Start service
```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization=0.8
```
#### Send requests
### Send requests
Send requests using either curl or Python
@@ -78,4 +79,4 @@ response = client.chat.completions.create(
stream=False,
)
print(response)
```
```

Some files were not shown because too many files have changed in this diff Show More