polish code with new pre-commit rule (#2923)

Author: Zero Rains
Date: 2025-07-19 23:19:27 +08:00
Committed by: GitHub
Parent: b8676d71a8
Commit: 25698d56d1
424 changed files with 14307 additions and 13518 deletions

.flake8 (new file, +7 lines)

@@ -0,0 +1,7 @@
+[flake8]
+ignore = E203, E402, E501, E731, E741, W503, W605, E722
+max-line-length = 119
+# E402: module level import not at top of file
+per-file-ignores =
+    __init__.py:F401,F403,E402
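The per-file-ignores entry reflects how package __init__.py files are usually written: they re-export symbols and may run a little setup before importing. A minimal sketch of the patterns involved (standard-library modules stand in for the real package contents, which are not shown in this commit):

# Hypothetical __init__.py-style module, for illustration only.
import os

os.environ.setdefault("EXAMPLE_LOG_LEVEL", "INFO")  # code before imports triggers E402 on the imports below

from json import dumps, loads  # re-exported but unused here, so F401
from math import *  # star re-export, so F403

__all__ = ["dumps", "loads"]

Without the per-file ignore, flake8 would flag each of these lines even though they are intentional in package init files.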


@@ -5,12 +5,27 @@ default_stages:
   - pre-commit # Run locally
   # - manual # Run in CI
 repos:
+- repo: https://github.com/psf/black.git
+  rev: 22.8.0
+  hooks:
+  - id: black
+    files: \.(py|pyi)$
+    additional_dependencies: [toml]
+# auto-sort imports
+- repo: https://github.com/PyCQA/isort
+  rev: 5.11.5
+  hooks:
+  - id: isort
+- repo: https://github.com/PyCQA/flake8
+  rev: 4.0.1
+  hooks:
+  - id: flake8
 # code checks
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.11.7
   hooks:
   - id: ruff
-    args: [--output-format, github, --fix, --line-length=120]
+    args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
 # # spell check
 # - repo: https://github.com/codespell-project/codespell
 #   rev: v2.4.1
@@ -18,17 +33,13 @@ repos:
 #   - id: codespell
 #     additional_dependencies: ['tomli']
 #     args: ['--toml', 'pyproject.toml']
-# auto-sort imports
-- repo: https://github.com/PyCQA/isort
-  rev: 6.0.1
-  hooks:
-  - id: isort
 # markdown
 - repo: https://github.com/jackdewinter/pymarkdown
   rev: v0.9.29
   hooks:
   - id: pymarkdown
-    args: [fix]
+    args: ["-d", "MD029,MD031", fix]
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v5.0.0
   hooks:
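With these hooks configured, the usual workflow is to install them once and then run them over the whole tree. A minimal sketch that drives the standard pre-commit CLI from Python (the commands shown are ordinary pre-commit usage, not something added by this commit):

# run_precommit.py: install the git hook, then run every configured hook once.
import subprocess

def run_hooks() -> None:
    subprocess.run(["pre-commit", "install"], check=True)  # register the pre-commit git hook
    # check=False because pre-commit exits non-zero whenever a hook rewrites or flags a file
    subprocess.run(["pre-commit", "run", "--all-files"], check=False)

if __name__ == "__main__":
    run_hooks()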


@@ -29,13 +29,13 @@ from typing import Optional
 import aiohttp
 from tqdm.asyncio import tqdm

 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


 @dataclass
 class RequestFuncInput:
     """Input for requesting LLMs via API"""

     no: int
     prompt: str
     history_QA: Optional[dict]
@@ -55,6 +55,7 @@ class RequestFuncInput:
 @dataclass
 class RequestFuncOutput:
     """Output for requesting LLMs via API"""
+
     no: int = 0
     generated_text: str = ""
     reasoning_content: str = ""
@@ -76,12 +77,9 @@ async def async_request_eb_openai_chat_completions(
 ) -> RequestFuncOutput:
     """Request an LLM using EB OpenAI"""
     api_url = request_func_input.api_url
-    assert api_url.endswith(
-        ("completions", "profile")
-    ), "OpenAI Chat Completions API URL must end with 'completions'."
+    assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."

-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         content = [{"type": "text", "text": request_func_input.prompt}]
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
@@ -91,7 +89,7 @@ async def async_request_eb_openai_chat_completions(
             "stream": True,
             "stream_options": {
                 "include_usage": True,
-                "continuous_usage_stats": True
+                "continuous_usage_stats": True,
             },
         }
         # hyperparameters are passed in via yaml
@@ -100,7 +98,7 @@ async def async_request_eb_openai_chat_completions(
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos

-        print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
+        print(f"payload:{json.dumps(payload, ensure_ascii=False)}")

         headers = {
             "Content-Type": "application/json",
@@ -115,16 +113,14 @@ async def async_request_eb_openai_chat_completions(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                         if chunk != "[DONE]":
                             # print("####chunk:", chunk, type(chunk))
                             timestamp = time.perf_counter()
@@ -138,22 +134,20 @@ async def async_request_eb_openai_chat_completions(
                                     ttft = timestamp - st
                                     output.ttft = ttft
                                     # cached_tokens
-                                    output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
+                                    output.prompt_len = (
+                                        data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
+                                    )

                                 # Decoding phase
                                 else:
-                                    output.itl.append(timestamp -
-                                                      most_recent_timestamp)
+                                    output.itl.append(timestamp - most_recent_timestamp)

                                 output.generated_text += content or ""
                                 output.reasoning_content += reason_content or ""
                                 output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                             elif usage := data.get("usage", {}):
-                                output.output_tokens = usage.get(
-                                    "completion_tokens", 0)
-                                output.prompt_tokens = usage.get(
-                                    "prompt_tokens", 0)
+                                output.output_tokens = usage.get("completion_tokens", 0)
+                                output.prompt_tokens = usage.get("prompt_tokens", 0)

                             most_recent_timestamp = timestamp
@@ -166,7 +160,12 @@ async def async_request_eb_openai_chat_completions(
                     output.latency = most_recent_timestamp - st
                 else:
                     error_text = await response.text()
-                    print("####error response:", error_text, "####payload:", payload)
+                    print(
+                        "####error response:",
+                        error_text,
+                        "####payload:",
+                        payload,
+                    )
                     output.error = error_text or ""
                     output.success = False
         except Exception:
@@ -194,15 +193,14 @@ async def async_request_eb_openai_completions(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
             "model": request_func_input.model,
             "prompt": request_func_input.prompt,
             "stream": True,
             "stream_options": {
                 "include_usage": True,
-                "continuous_usage_stats": True
+                "continuous_usage_stats": True,
             },
         }
         # hyperparameters are passed in via yaml
@@ -215,7 +213,7 @@ async def async_request_eb_openai_completions(
         headers = {
             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
-            "Content-Type": "application/json"
+            "Content-Type": "application/json",
         }

         output = RequestFuncOutput()
@@ -227,8 +225,7 @@ async def async_request_eb_openai_completions(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     first_chunk_received = False
                     async for chunk_bytes in response.content:
@@ -236,8 +233,7 @@ async def async_request_eb_openai_completions(
                         if not chunk_bytes:
                             continue
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                         if chunk != "[DONE]":
                             # print("####chunk:", chunk, chunk.usage)
                             timestamp = time.perf_counter()
@@ -259,25 +255,22 @@ async def async_request_eb_openai_completions(
                             # Decoding phase
                             else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
+                                output.itl.append(timestamp - most_recent_timestamp)

                             generated_text += text or ""
                             most_recent_timestamp = timestamp
                             output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                         elif usage := data.get("usage"):
-                            output.prompt_tokens = usage.get(
-                                "prompt_tokens")
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
+                            output.prompt_tokens = usage.get("prompt_tokens")
+                            output.output_tokens = usage.get("completion_tokens")

                     if first_chunk_received:
                         output.success = True
                     else:
                         output.success = False
                         output.error = (
-                            "Never received a valid chunk to calculate TTFT."
-                            "This response will be marked as failed!")
+                            "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
+                        )
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
@@ -295,7 +288,7 @@ async def async_request_eb_openai_completions(
             exc_info = sys.exc_info()
             output.error = "".join(traceback.format_exception(*exc_info))

-        print("final_output:{}".format(output))
+        print(f"final_output:{output}")

         if pbar:
             pbar.update(1)
@@ -310,8 +303,7 @@ async def async_request_tgi(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")

-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         params = {
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
@@ -358,8 +350,7 @@ async def async_request_tgi(
                         # Decoding phase
                         else:
-                            output.itl.append(timestamp -
-                                              most_recent_timestamp)
+                            output.itl.append(timestamp - most_recent_timestamp)

                         most_recent_timestamp = timestamp
                         output.arrival_time.append(data["arrival_time"])
@@ -388,8 +379,7 @@ async def async_request_trt_llm(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")

-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -414,8 +404,7 @@ async def async_request_trt_llm(
                         if not chunk_bytes:
                             continue
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data:")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data:")

                         data = json.loads(chunk)
                         output.generated_text += data["text_output"]
@@ -427,8 +416,7 @@ async def async_request_trt_llm(
                         # Decoding phase
                         else:
-                            output.itl.append(timestamp -
-                                              most_recent_timestamp)
+                            output.itl.append(timestamp - most_recent_timestamp)

                         most_recent_timestamp = timestamp
@@ -453,8 +441,7 @@ async def async_request_deepspeed_mii(
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     """Request an LLM using Deepspeed MII"""
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:

         payload = {
             "prompt": request_func_input.prompt,
@@ -472,19 +459,16 @@ async def async_request_deepspeed_mii(
         st = time.perf_counter()
         try:
-            async with session.post(url=request_func_input.api_url,
-                                    json=payload) as response:
+            async with session.post(url=request_func_input.api_url, json=payload) as response:
                 if response.status == 200:
                     parsed_resp = await response.json()
                     output.latency = time.perf_counter() - st
                     if "choices" in parsed_resp:
-                        output.generated_text = parsed_resp["choices"][0][
-                            "text"]
+                        output.generated_text = parsed_resp["choices"][0]["text"]
                     elif "text" in parsed_resp:
                         output.generated_text = parsed_resp["text"][0]
                     else:
-                        output.error = ("Unexpected response format: "
-                                        "neither 'choices' nor 'text' found")
+                        output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
                         output.success = False

                     output.success = True
                 else:
@@ -510,11 +494,9 @@ async def async_request_openai_completions(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": request_func_input.model_name \
-                if request_func_input.model_name else request_func_input.model,
+            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
             "prompt": request_func_input.prompt,
             # "temperature": 0.0,
             "max_tokens": request_func_input.output_len,
@@ -527,9 +509,7 @@ async def async_request_openai_completions(
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
-        headers = {
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
-        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
@@ -538,8 +518,7 @@ async def async_request_openai_completions(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     first_chunk_received = False
                     async for chunk_bytes in response.content:
@@ -547,8 +526,7 @@ async def async_request_openai_completions(
                         if not chunk_bytes:
                             continue
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                         if chunk != "[DONE]":
                             # print("####chunk:", chunk, type(chunk))
                             data = json.loads(chunk)
@@ -569,21 +547,19 @@ async def async_request_openai_completions(
                             # Decoding phase
                             else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
+                                output.itl.append(timestamp - most_recent_timestamp)

                             most_recent_timestamp = timestamp
                             generated_text += text or ""
                         elif usage := data.get("usage"):
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
+                            output.output_tokens = usage.get("completion_tokens")
                     if first_chunk_received:
                         output.success = True
                     else:
                         output.success = False
                         output.error = (
-                            "Never received a valid chunk to calculate TTFT."
-                            "This response will be marked as failed!")
+                            "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
+                        )
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
                 else:
@@ -606,25 +582,24 @@ async def async_request_openai_audio(
     """Request an LLM using OpenAI"""
     # Lazy import without PlaceholderModule to avoid vllm dep.
     import soundfile

     api_url = request_func_input.api_url
     assert api_url.endswith(
-        ("transcriptions", "translations"
-        )), "OpenAI Chat Completions API URL must end with 'transcriptions' "
+        ("transcriptions", "translations")
+    ), "OpenAI Chat Completions API URL must end with 'transcriptions' "
     "or `translations`."

-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
         content = [{"type": "text", "text": request_func_input.prompt}]
         payload = {
-            "model": request_func_input.model_name \
-                if request_func_input.model_name else request_func_input.model,
+            "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
             "temperature": 0.0,
             "max_completion_tokens": request_func_input.output_len,
             "stream": True,
             "language": "en",
             # Flattened due to multipart/form-data
             "stream_include_usage": True,
-            "stream_continuous_usage_stats": True
+            "stream_continuous_usage_stats": True,
         }
         if request_func_input.extra_body:
             payload.update(request_func_input.extra_body)
@@ -639,9 +614,9 @@ async def async_request_openai_audio(
             buffer.seek(0)
             return buffer

-        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
+        with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
             form = aiohttp.FormData()
-            form.add_field('file', f, content_type='audio/wav')
+            form.add_field("file", f, content_type="audio/wav")
             for key, value in payload.items():
                 form.add_field(key, str(value))
@@ -653,24 +628,20 @@ async def async_request_openai_audio(
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
-            async with session.post(url=api_url,
-                                    data=form,
-                                    headers=headers) as response:
+            async with session.post(url=api_url, data=form, headers=headers) as response:
                 if response.status == 200:
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue
-                        chunk = chunk_bytes.decode("utf-8").removeprefix(
-                            "data: ")
+                        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                         if chunk != "[DONE]":
                             timestamp = time.perf_counter()
                             data = json.loads(chunk)
                             if choices := data.get("choices"):
-                                content = choices[0]["delta"].get(
-                                    "content")
+                                content = choices[0]["delta"].get("content")
                                 # First token
                                 if ttft == 0.0:
                                     ttft = timestamp - st
@@ -678,13 +649,11 @@ async def async_request_openai_audio(
                                 # Decoding phase
                                 else:
-                                    output.itl.append(
-                                        timestamp - most_recent_timestamp)
+                                    output.itl.append(timestamp - most_recent_timestamp)

                                 generated_text += content or ""
                         elif usage := data.get("usage"):
-                            output.output_tokens = usage.get(
-                                "completion_tokens")
+                            output.output_tokens = usage.get("completion_tokens")

                             most_recent_timestamp = timestamp
@@ -718,8 +687,11 @@ ASYNC_REQUEST_FUNCS = {
 }

 OPENAI_COMPATIBLE_BACKENDS = [
-    k for k, v in ASYNC_REQUEST_FUNCS.items()
-    if v in (async_request_openai_completions,
-             async_request_eb_openai_chat_completions)
+    k
+    for k, v in ASYNC_REQUEST_FUNCS.items()
+    if v
+    in (
+        async_request_openai_completions,
+        async_request_eb_openai_chat_completions,
+    )
 ]


@@ -26,9 +26,9 @@ from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass
 from io import BytesIO
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union

 from PIL import Image

 logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ class SampleRequest:
     """
     Represents a single inference request for benchmarking.
     """
+
     no: int
     prompt: Union[str, Any]
     history_QA: Union[str, Any]
@@ -48,6 +49,7 @@
 class BenchmarkDataset(ABC):
     """BenchmarkDataset"""
+
     DEFAULT_SEED = 0
     IS_MULTIMODAL = False
@@ -68,8 +70,7 @@ class BenchmarkDataset(ABC):
         self.dataset_path = dataset_path
         # Set the random seed, ensuring that a None value is replaced with the
         # default seed.
-        self.random_seed = (random_seed
-                            if random_seed is not None else self.DEFAULT_SEED)
+        self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
         self.data = None
         self.hyperparameter_path = hyperparameter_path
         self.hyperparameters = {}
@@ -85,8 +86,7 @@ class BenchmarkDataset(ABC):
             NotImplementedError: If a subclass does not implement this method.
         """
         # TODO (jenniferzhao): add support for downloading data
-        raise NotImplementedError(
-            "load_data must be implemented in subclasses.")
+        raise NotImplementedError("load_data must be implemented in subclasses.")

     @abstractmethod
     def sample(self, num_requests: int) -> list[SampleRequest]:
@@ -105,8 +105,7 @@ class BenchmarkDataset(ABC):
         """
         raise NotImplementedError("sample must be implemented in subclasses.")

-    def maybe_oversample_requests(self, requests: list[SampleRequest],
-                                  num_requests: int) -> None:
+    def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
         """
         Oversamples the list of requests if its size is less than the desired
         number.
@@ -117,11 +116,9 @@ class BenchmarkDataset(ABC):
         """
         if len(requests) < num_requests:
             random.seed(self.random_seed)
-            additional = random.choices(requests,
-                                        k=num_requests - len(requests))
+            additional = random.choices(requests, k=num_requests - len(requests))
             requests.extend(additional)
-            logger.info("Oversampled requests to reach %d total samples.",
-                        num_requests)
+            logger.info("Oversampled requests to reach %d total samples.", num_requests)


 def is_valid_sequence(
@@ -141,14 +138,12 @@ def is_valid_sequence(
     """
     # Check for invalid conditions
     prompt_too_short = prompt_len < min_len
-    output_too_short = (not skip_min_output_len_check) and (output_len
-                                                            < min_len)
+    output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
     prompt_too_long = prompt_len > max_prompt_len
     combined_too_long = (prompt_len + output_len) > max_total_len

     # Return True if none of the invalid conditions are met
-    return not (prompt_too_short or output_too_short or prompt_too_long
-                or combined_too_long)
+    return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)


 def process_image(image: Any) -> Mapping[str, Any]:
@@ -171,28 +166,25 @@ def process_image(image: Any) -> Mapping[str, Any]:
     Raises:
         ValueError: If the input is not a supported type.
     """
-    if isinstance(image, dict) and 'bytes' in image:
-        image = Image.open(BytesIO(image['bytes']))
+    if isinstance(image, dict) and "bytes" in image:
+        image = Image.open(BytesIO(image["bytes"]))
     if isinstance(image, Image.Image):
         image = image.convert("RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
-            image_base64 = base64.b64encode(
-                image_data.getvalue()).decode("utf-8")
+            image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
         return {
             "type": "image_url",
-            "image_url": {
-                "url": f"data:image/jpeg;base64,{image_base64}"
-            },
+            "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
         }

     if isinstance(image, str):
-        image_url = (image if image.startswith(
-            ("http://", "file://")) else f"file://{image}")
+        image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
         return {"type": "image_url", "image_url": {"url": image_url}}

-    raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
-                     " or str or dictionary with raw image bytes.")
+    raise ValueError(
+        f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
+    )


 class EBDataset(BenchmarkDataset):
@@ -243,8 +235,7 @@ class EBDataset(BenchmarkDataset):
             new_output_len = int(entry["max_dec_len"])

             if enable_multimodal_chat:
-                prompt = self.apply_multimodal_chat_transformation(
-                    prompt, None)
+                prompt = self.apply_multimodal_chat_transformation(prompt, None)
             samples.append(
                 SampleRequest(
                     no=cnt,
@@ -252,17 +243,20 @@ class EBDataset(BenchmarkDataset):
                     prompt_len=self.prompt_len,
                     history_QA=[],
                     expected_output_len=new_output_len,
-                ))
+                )
+            )
             cnt += 1

         self.maybe_oversample_requests(samples, num_requests)
         return samples


 class EBChatDataset(BenchmarkDataset):
     """
     Implements the ShareGPT dataset. Loads data from a JSON file and generates
     sample requests based on conversation turns.
     """
+
     prompt_len: int

     def __init__(self, **kwargs) -> None:
@@ -296,8 +290,7 @@ class EBChatDataset(BenchmarkDataset):
             new_output_len = int(entry.get("max_tokens", 12288))

             if enable_multimodal_chat:
-                prompt = self.apply_multimodal_chat_transformation(
-                    prompt, None)
+                prompt = self.apply_multimodal_chat_transformation(prompt, None)
             samples.append(
                 SampleRequest(
                     no=cnt,
@@ -306,9 +299,9 @@ class EBChatDataset(BenchmarkDataset):
                     prompt_len=0,
                     history_QA=history_QA,
                     expected_output_len=new_output_len,
-                ))
+                )
+            )
             cnt += 1

         self.maybe_oversample_requests(samples, num_requests)
         return samples


@@ -18,28 +18,16 @@ import argparse
 import asyncio
 import contextlib
 import os
-import signal
-import socket
-import subprocess
-import time
 from typing import Union

-import openai
-import yaml
-from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
+from benchmark_dataset import EBChatDataset, EBDataset
 from benchmark_serving import benchmark


-def prepare_input_requests(
-    num_prompts: int, dataset_name: str, dataset_path: str
-) -> Union[EBDataset, EBChatDataset]:
+def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
     dataset_mapping = {
-        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
-            num_requests=num_prompts
-        ),
-        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
-            num_requests=num_prompts
-        ),
+        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
+        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
     }

     try:
@@ -104,24 +92,27 @@ def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
 def main(args):
     base_url = f"http://{args.host}:{args.port}"

-    input_requests = prepare_input_requests(
-        args.num_prompts, args.dataset_name, args.dataset_path
-    )
+    input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)

     if len(args.max_concurrency) != len(args.s_itl_base_model):
-        raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
+        raise ValueError("--max_concurrency should be same length as --s_itl_base_model")

     for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
         # Wramup
         print("Starting warmup...")
         with open(os.devnull, "w") as f:
             with contextlib.redirect_stdout(f):
-                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
+                send_one_batch(
+                    base_url,
+                    max_concurrency,
+                    input_requests[0:max_concurrency],
+                    True,
+                )

         # Benchmark
         record = send_one_batch(base_url, max_concurrency, input_requests, False)

-        metric_header = f"Speed up"
+        metric_header = "Speed up"
         print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
         for draft_token_step in args.draft_token_steps:
             speedup = calculate_speedup(
@@ -130,11 +121,7 @@ def main(args):
                 s_itl,
                 record["mean_s_itl_ms"],
             )
-            print(
-                "{:<40} {:<10.2f}".format(
-                    f"Speed up on {draft_token_step} steps draft", speedup
-                )
-            )
+            print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))

         print("=" * 50)


@@ -25,22 +25,23 @@ import os
 import random
 import time
 import warnings
-import yaml
+from argparse import ArgumentParser as FlexibleArgumentParser
 from collections.abc import AsyncGenerator, Iterable
 from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Optional

 import numpy as np
-from backend_request_func import (ASYNC_REQUEST_FUNCS,
-                                  OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
-                                  RequestFuncOutput)
-from tqdm.asyncio import tqdm
-
-from argparse import ArgumentParser as FlexibleArgumentParser
-
-from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
+import yaml
+from backend_request_func import (
+    ASYNC_REQUEST_FUNCS,
+    OPENAI_COMPATIBLE_BACKENDS,
+    RequestFuncInput,
+    RequestFuncOutput,
+)
+from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
+from tqdm.asyncio import tqdm

 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -48,6 +49,7 @@ MILLISECONDS_TO_SECONDS_CONVERSION = 1000
 @dataclass
 class BenchmarkMetrics:
     """Class containing all metrics that are used in this script"""
+
     completed: int
     total_input: int
     total_output: int
@@ -130,8 +132,7 @@ async def get_request(
     input_requests: Iterable[SampleRequest] = iter(input_requests)

     # Calculate scale parameter theta to maintain the desired request_rate.
-    assert burstiness > 0, (
-        f"A positive burstiness factor is expected, but given {burstiness}.")
+    assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
     theta = 1.0 / (request_rate * burstiness)

     for request in input_requests:
@@ -208,8 +209,9 @@ def calculate_metrics(
             s_e2els.append(outputs[i].arrival_time[-1])
             # decode speed, excluding the first token
             if len(outputs[i].arrival_time) > 2:
-                s_decodes.append((outputs[i].output_tokens - 1) /
-                                 (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
+                s_decodes.append(
+                    (outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
+                )
             else:
                 print("len(outputs[i].arrival_time) <= 2")
             completed += 1
@@ -224,16 +226,13 @@
         if "ttft" in goodput_config_dict:
             valid_metrics.append(ttfts)
-            slo_values.append(goodput_config_dict["ttft"] /
-                              MILLISECONDS_TO_SECONDS_CONVERSION)
+            slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
         if "tpot" in goodput_config_dict:
             valid_metrics.append(all_tpots)
-            slo_values.append(goodput_config_dict["tpot"] /
-                              MILLISECONDS_TO_SECONDS_CONVERSION)
+            slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
         if "e2el" in goodput_config_dict:
             valid_metrics.append(e2els)
-            slo_values.append(goodput_config_dict["e2el"] /
-                              MILLISECONDS_TO_SECONDS_CONVERSION)
+            slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)

         for req_metric in zip(*valid_metrics):
             is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@@ -242,9 +241,9 @@ def calculate_metrics(
     if completed == 0:
         warnings.warn(
-            "All requests failed. This is likely due to a misconfiguration "
-            "on the benchmark arguments.",
-            stacklevel=2)
+            "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
+            stacklevel=2,
+        )
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
@@ -253,64 +252,50 @@ def calculate_metrics(
         request_goodput=good_completed / dur_s,
         output_throughput=sum(actual_output_lens) / dur_s,
         total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
-        mean_s_decode=np.mean(s_decodes or 0) *
-        1,  # ttfts is empty if streaming is not supported by backend
+        mean_s_decode=np.mean(s_decodes or 0) * 1,  # ttfts is empty if streaming is not supported by backend
         std_s_decode=np.std(s_decodes or 0) * 1,
         median_s_decode=np.median(s_decodes or 0) * 1,
-        percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1)
-                              for p in selected_percentiles],
-        mean_ttft_ms=np.mean(ttfts or 0) *
-        1000,  # ttfts is empty if streaming is not supported by backend
+        percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
+        mean_ttft_ms=np.mean(ttfts or 0) * 1000,  # ttfts is empty if streaming is not supported by backend
         std_ttft_ms=np.std(ttfts or 0) * 1000,
         median_ttft_ms=np.median(ttfts or 0) * 1000,
-        percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
-                             for p in selected_percentiles],
-        mean_s_ttft_ms=np.mean(s_ttfts or 0) *
-        1000,  # ttfts is empty if streaming is not supported by backend
+        percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
+        mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000,  # ttfts is empty if streaming is not supported by backend
         std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
         median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
-        percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000)
-                               for p in selected_percentiles],
+        percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
         mean_tpot_ms=np.mean(tpots or 0) * 1000,
         std_tpot_ms=np.std(tpots or 0) * 1000,
         median_tpot_ms=np.median(tpots or 0) * 1000,
-        percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
-                             for p in selected_percentiles],
+        percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
         mean_itl_ms=np.mean(itls or 0) * 1000,
         std_itl_ms=np.std(itls or 0) * 1000,
         median_itl_ms=np.median(itls or 0) * 1000,
-        percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
-                            for p in selected_percentiles],
+        percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
         mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
         std_s_itl_ms=np.std(s_itls or 0) * 1000,
         median_s_itl_ms=np.median(s_itls or 0) * 1000,
-        percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000)
-                              for p in selected_percentiles],
+        percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
         mean_e2el_ms=np.mean(e2els or 0) * 1000,
         std_e2el_ms=np.std(e2els or 0) * 1000,
         median_e2el_ms=np.median(e2els or 0) * 1000,
-        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
-                             for p in selected_percentiles],
+        percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
         mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
         std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
         median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
-        percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000)
-                               for p in selected_percentiles],
+        percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
         mean_input_len=np.mean(input_lens or 0) * 1,
         std_input_len=np.std(input_lens or 0) * 1,
         median_input_len=np.median(input_lens or 0) * 1,
-        percentiles_input_len=[(p, np.percentile(input_lens or 0, p))
-                               for p in selected_percentiles],
+        percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
         mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
         std_s_input_len=np.std(infer_input_lens or 0) * 1,
         median_s_input_len=np.median(infer_input_lens or 0) * 1,
-        percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p))
-                                 for p in selected_percentiles],
+        percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
         mean_output_len=np.mean(actual_output_lens or 0) * 1,
         std_output_len=np.std(actual_output_lens or 0) * 1,
         median_output_len=np.median(actual_output_lens or 0) * 1,
-        percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p))
-                                for p in selected_percentiles],
+        percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
     )

     return metrics, actual_output_lens
@@ -344,9 +329,11 @@ async def benchmark(
         raise ValueError(f"Unknown backend: {backend}")

     print("Starting initial single prompt test run...")
-    test_prompt, test_output_len, test_no = \
-        input_requests[0].prompt, \
-        input_requests[0].expected_output_len, input_requests[0].no
+    test_prompt, test_output_len, test_no = (
+        input_requests[0].prompt,
+        input_requests[0].expected_output_len,
+        input_requests[0].no,
+    )
     test_history_QA = input_requests[0].history_QA

     test_input = RequestFuncInput(
@@ -373,19 +360,19 @@ async def benchmark(
     if not test_output.success:
         raise ValueError(
             "Initial test run failed - Please make sure benchmark arguments "
-            f"are correctly specified. Error: {test_output.error}")
+            f"are correctly specified. Error: {test_output.error}"
+        )
     else:
         print("Initial test run completed. Starting main benchmark run...")

     if lora_modules:
         # For each input request, choose a LoRA module at random.
-        lora_modules = iter(
-            [random.choice(lora_modules) \
-                for _ in range(len(input_requests))])
+        lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])

     if profile:
         print("Starting profiler...")
-        profile_input = RequestFuncInput(model=model_id,
+        profile_input = RequestFuncInput(
+            model=model_id,
             model_name=model_name,
             prompt=test_prompt,
             no=test_no,
@@ -393,7 +380,8 @@ async def benchmark(
             output_len=test_output_len,
             logprobs=logprobs,
             ignore_eos=ignore_eos,
-            extra_body=extra_body)
+            extra_body=extra_body,
+        )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
             print("Profiler started")
@@ -413,21 +401,22 @@ async def benchmark(
     # and it will simplify the code in limited_request_func.
     # semaphore = (asyncio.Semaphore(max_concurrency)
     #              if max_concurrency else contextlib.nullcontext())
-    semaphore = (asyncio.Semaphore(max_concurrency)
-                 if max_concurrency else None)
+    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

     async def limited_request_func(request_func_input, pbar):
         if semaphore is None:
-            return await request_func(request_func_input=request_func_input,
-                                      pbar=pbar)
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
         async with semaphore:
-            return await request_func(request_func_input=request_func_input,
-                                      pbar=pbar)
+            return await request_func(request_func_input=request_func_input, pbar=pbar)

     benchmark_start_time = time.perf_counter()
     tasks: list[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate, burstiness):
-        prompt, output_len, no = request.prompt, request.expected_output_len, request.no
+        prompt, output_len, no = (
+            request.prompt,
+            request.expected_output_len,
+            request.no,
+        )
         history_QA = request.history_QA

         req_model_id, req_model_name = model_id, model_name
@@ -435,7 +424,8 @@ async def benchmark(
             req_lora_module = next(lora_modules)
             req_model_id, req_model_name = req_lora_module, req_lora_module

-        request_func_input = RequestFuncInput(model=req_model_id,
+        request_func_input = RequestFuncInput(
+            model=req_model_id,
             model_name=req_model_name,
             prompt=prompt,
             no=no,
@@ -446,11 +436,9 @@ async def benchmark(
             output_len=output_len,
             logprobs=logprobs,
             ignore_eos=ignore_eos,
-            extra_body=extra_body)
-        tasks.append(
-            asyncio.create_task(
-                limited_request_func(request_func_input=request_func_input,
-                                     pbar=pbar)))
+            extra_body=extra_body,
+        )
+        tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

     if profile:
@@ -473,7 +461,6 @@ async def benchmark(
     benchmark_duration = time.perf_counter() - benchmark_start_time
     print("benchmark_duration:", benchmark_duration)
-
     metrics, actual_output_lens = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,
@@ -483,22 +470,16 @@ async def benchmark(
         goodput_config_dict=goodput_config_dict,
     )

-    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
-    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
-                                    benchmark_duration))
+    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
-    print("{:<40} {:<10}".format("Total generated tokens:",
-                                 metrics.total_output))
-    print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
-                                    metrics.request_throughput))
+    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+    print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
     if goodput_config_dict:
-        print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
-                                        metrics.request_goodput))
-    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
-                                    metrics.output_throughput))
-    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
-                                    metrics.total_token_throughput))
+        print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
+    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
+    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))

     result = {
         "duration": benchmark_duration,
@@ -506,8 +487,7 @@ async def benchmark(
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
         "request_throughput": metrics.request_throughput,
-        "request_goodput:":
-        metrics.request_goodput if goodput_config_dict else None,
+        "request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
         "output_throughput": metrics.output_throughput,
         "total_token_throughput": metrics.total_token_throughput,
         "input_lens": [output.prompt_len for output in outputs],
@@ -533,24 +513,25 @@ async def benchmark(
         # metric.
         if metric_attribute_name not in selected_percentile_metrics:
             return
-        print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
-        print("{:<40} {:<10.2f}".format(
-            f"Mean {metric_name} (ms):",
-            getattr(metrics, f"mean_{metric_attribute_name}_ms")))
-        print("{:<40} {:<10.2f}".format(
-            f"Median {metric_name} (ms):",
-            getattr(metrics, f"median_{metric_attribute_name}_ms")))
-        result[f"mean_{metric_attribute_name}_ms"] = getattr(
-            metrics, f"mean_{metric_attribute_name}_ms")
-        result[f"median_{metric_attribute_name}_ms"] = getattr(
-            metrics, f"median_{metric_attribute_name}_ms")
-        result[f"std_{metric_attribute_name}_ms"] = getattr(
-            metrics, f"std_{metric_attribute_name}_ms")
-        for p, value in getattr(metrics,
-                                f"percentiles_{metric_attribute_name}_ms"):
+        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+        print(
+            "{:<40} {:<10.2f}".format(
+                f"Mean {metric_name} (ms):",
+                getattr(metrics, f"mean_{metric_attribute_name}_ms"),
+            )
+        )
+        print(
+            "{:<40} {:<10.2f}".format(
+                f"Median {metric_name} (ms):",
+                getattr(metrics, f"median_{metric_attribute_name}_ms"),
+            )
+        )
+        result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
+        result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
+        result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
+        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
             p_word = str(int(p)) if int(p) == p else str(p)
-            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
-                                            value))
+            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
             result[f"p{p_word}_{metric_attribute_name}_ms"] = value

     def process_one_length(
@@ -565,31 +546,31 @@ async def benchmark(
         # metric.
         if metric_attribute_name not in selected_percentile_metrics:
             return
-        print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
-        print("{:<40} {:<10.2f}".format(
-            f"Mean {metric_name}:",
-            getattr(metrics, f"mean_{metric_attribute_name}")))
-        print("{:<40} {:<10.2f}".format(
-            f"Median {metric_name}:",
-            getattr(metrics, f"median_{metric_attribute_name}")))
-        result[f"mean_{metric_attribute_name}"] = getattr(
-            metrics, f"mean_{metric_attribute_name}")
-        result[f"median_{metric_attribute_name}"] = getattr(
-            metrics, f"median_{metric_attribute_name}")
-        result[f"std_{metric_attribute_name}"] = getattr(
-            metrics, f"std_{metric_attribute_name}")
-        for p, value in getattr(metrics,
-                                f"percentiles_{metric_attribute_name}"):
+        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+        print(
+            "{:<40} {:<10.2f}".format(
+                f"Mean {metric_name}:",
+                getattr(metrics, f"mean_{metric_attribute_name}"),
+            )
+        )
+        print(
+            "{:<40} {:<10.2f}".format(
+                f"Median {metric_name}:",
+                getattr(metrics, f"median_{metric_attribute_name}"),
+            )
+        )
+        result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
+        result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
+        result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
+        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
             p_word = str(int(p)) if int(p) == p else str(p)
-            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
-                                            value))
+            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
             result[f"p{p_word}_{metric_attribute_name}"] = value

     process_one_length("s_decode", "Decode", "Decode speed (tok/s)")
     process_one_metric("ttft", "TTFT", "Time to First Token")
     process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
-    process_one_metric("tpot", "TPOT",
-                       "Time per Output Token (excl. 1st token)")
+    process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
     process_one_metric("itl", "ITL", "Inter-token Latency")
     process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
     process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -612,7 +593,6 @@ def benchmark_metrics(
 ):
     """Benchmark metrics statistics, generate benchmark result"""
     outputs = []
-    case_no_list = []
     with open(result_file) as f:
         for line in f.readlines():
             if "RequestFuncOutput" in line:
@@ -634,22 +614,16 @@ def benchmark_metrics(
         goodput_config_dict=goodput_config_dict,
     )

-    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
-    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
-                                    benchmark_duration))
+    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
-    print("{:<40} {:<10}".format("Total generated tokens:",
-                                 metrics.total_output))
-    print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
-                                    metrics.request_throughput))
+    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+    print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput))
     if goodput_config_dict:
-        print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
-                                        metrics.request_goodput))
-    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
-                                    metrics.output_throughput))
-    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
-                                    metrics.total_token_throughput))
+        print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
+    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
+    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))

     result = {
         "duration": benchmark_duration,
@@ -657,8 +631,7 @@ def benchmark_metrics(
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
         "request_throughput": metrics.request_throughput,
-        "request_goodput:":
-        metrics.request_goodput if goodput_config_dict else None,
+        "request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
         "output_throughput": metrics.output_throughput,
         "total_token_throughput": metrics.total_token_throughput,
         "input_lens": [output.prompt_len for output in outputs],
@@ -682,24 +655,25 @@ def benchmark_metrics(
         # metric.
         if metric_attribute_name not in selected_percentile_metrics:
             return
-        print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
-        print("{:<40} {:<10.2f}".format(
-            f"Mean {metric_name} (ms):",
-            getattr(metrics, f"mean_{metric_attribute_name}_ms")))
-        print("{:<40} {:<10.2f}".format(
-            f"Median {metric_name} (ms):",
+        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
+        print(
+            "{:<40} {:<10.2f}".format(
+                f"Mean {metric_name} (ms):",
+                getattr(metrics, f"mean_{metric_attribute_name}_ms"),
+            )
+        )
+        print(
+            "{:<40} {:<10.2f}".format(
+                f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"))) getattr(metrics, f"median_{metric_attribute_name}_ms"),
result[f"mean_{metric_attribute_name}_ms"] = getattr( )
metrics, f"mean_{metric_attribute_name}_ms") )
result[f"median_{metric_attribute_name}_ms"] = getattr( result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
metrics, f"median_{metric_attribute_name}_ms") result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr( result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
metrics, f"std_{metric_attribute_name}_ms") for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length( def process_one_length(
@@ -714,31 +688,31 @@ def benchmark_metrics(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name}:", f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}"))) getattr(metrics, f"mean_{metric_attribute_name}"),
print("{:<40} {:<10.2f}".format( )
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name}:", f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}"))) getattr(metrics, f"median_{metric_attribute_name}"),
result[f"mean_{metric_attribute_name}"] = getattr( )
metrics, f"mean_{metric_attribute_name}") )
result[f"median_{metric_attribute_name}"] = getattr( result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
metrics, f"median_{metric_attribute_name}") result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr( result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
metrics, f"std_{metric_attribute_name}") for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
value))
result[f"p{p_word}_{metric_attribute_name}"] = value result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)") process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token") process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT", process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency") process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -764,12 +738,14 @@ def check_goodput_args(args):
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of " "The service level objective name should be one of "
f"{str(VALID_NAMES)}. ") f"{VALID_NAMES!s}. "
)
if slo_val < 0: if slo_val < 0:
raise ValueError( raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative."
)
return goodput_config_dict return goodput_config_dict
@@ -783,32 +759,37 @@ def parse_goodput(slo_pairs):
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds."
) from err
return goodput_config_dict return goodput_config_dict
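
# --- illustrative example (not part of the diff) ---------------------------
# Hedged sketch of the assumed behaviour of parse_goodput (its full body is
# elided in this hunk): each "KEY:VALUE" pair becomes a float SLO in
# milliseconds, which check_goodput_args then validates against VALID_NAMES.
def parse_goodput_sketch(slo_pairs):
    goodput_config = {}
    for slo_pair in slo_pairs:
        slo_name, slo_val = slo_pair.split(":")
        goodput_config[slo_name] = float(slo_val)
    return goodput_config

print(parse_goodput_sketch(["ttft:500", "tpot:50"]))  # {'ttft': 500.0, 'tpot': 50.0}
# --- end example ------------------------------------------------------------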
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
results: dict[str, Any],
file_name: str) -> None:
"""Save the benchmarking results to PyTorch Benchmark Format JSON file""" """Save the benchmarking results to PyTorch Benchmark Format JSON file"""
metrics = [ metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", "median_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", "mean_ttft_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" "std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
] ]
# These raw data might be useful, but they are rather big. They can be added # These raw data might be useful, but they are rather big. They can be added
# later if needed # later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={k: [results[k]] metrics={k: [results[k]] for k in metrics},
for k in metrics}, extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
extra_info={ )
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
if pt_records: if pt_records:
# Don't use json suffix here as we don't want CI to pick it up # Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
@@ -825,7 +806,6 @@ def main(args: argparse.Namespace):
model_id = args.model model_id = args.model
model_name = args.served_model_name model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None: if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}" api_url = f"{args.base_url}{args.endpoint}"
@@ -835,21 +815,15 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}" base_url = f"http://{args.host}:{args.port}"
if args.dataset_name is None: if args.dataset_name is None:
raise ValueError( raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
# For datasets that follow a similar structure, use a mapping. # For datasets that follow a similar structure, use a mapping.
dataset_mapping = { dataset_mapping = {
"EB": "EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
lambda: EBDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
), ),
"EBChat": "EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
lambda: EBChatDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
), ),
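
# --- illustrative example (not part of the diff) ---------------------------
# Sketch of the lazy-dispatch pattern used by dataset_mapping above: each
# dataset is wrapped in a callable so nothing is loaded until the chosen
# --dataset-name key is looked up and called. Names are illustrative.
def load_eb():
    return ["eb-request-1", "eb-request-2"]

def load_eb_chat():
    return ["eb-chat-request-1"]

dataset_mapping_sketch = {"EB": load_eb, "EBChat": load_eb_chat}
dataset_name = "EBChat"
try:
    input_requests = dataset_mapping_sketch[dataset_name]()
except KeyError as err:
    raise ValueError(f"Unknown dataset: {dataset_name}") from err
print(input_requests)  # ['eb-chat-request-1']
# --- end example ------------------------------------------------------------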
@@ -869,15 +843,14 @@ def main(args: argparse.Namespace):
"top_p": args.top_p, "top_p": args.top_p,
"top_k": args.top_k, "top_k": args.top_k,
"min_p": args.min_p, "min_p": args.min_p,
"temperature": args.temperature "temperature": args.temperature,
}.items() if v is not None }.items()
if v is not None
} }
# Sampling parameters are only supported by openai-compatible backend. # Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError( raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")
"Sampling parameters are only supported by openai-compatible "
"backends.")
if "temperature" not in sampling_params: if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding. sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -908,15 +881,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_body=sampling_params, extra_body=sampling_params,
)) )
)
# benchmark_result = benchmark_metrics( # benchmark_result = benchmark_metrics(
# benchmark_duration=3600, # benchmark_duration=3600,
@@ -947,22 +919,23 @@ def main(args: argparse.Namespace):
kvstring = item.split("=") kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip() result_json[kvstring[0].strip()] = kvstring[1].strip()
else: else:
raise ValueError( raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
"Invalid metadata format. Please use KEY=VALUE format."
)
if not args.save_detailed: if not args.save_detailed:
# Remove fields with too many data points # Remove fields with too many data points
for field in [ for field in [
"input_lens", "output_lens", "ttfts", "itls", "input_lens",
"generated_texts", "errors" "output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]: ]:
if field in result_json: if field in result_json:
del result_json[field] del result_json[field]
# Traffic # Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
@@ -971,21 +944,19 @@ def main(args: argparse.Namespace):
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}" max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
if args.max_concurrency is not None else "") file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
if args.result_dir: if args.result_dir:
file_name = os.path.join(args.result_dir, file_name) file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile: with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(result_json, outfile) json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name) save_to_pytorch_benchmark_format(args, result_json, file_name)
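
# --- illustrative example (not part of the diff) ---------------------------
# Example of the result file name built above, assuming backend="openai-chat",
# request_rate=inf, max_concurrency=8 and model "org/model". current_dt is
# computed earlier in main(); the timestamp format used here is an assumption.
from datetime import datetime

backend, request_rate, max_concurrency = "openai-chat", float("inf"), 8
base_model_id = "org/model".split("/")[-1]
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")  # assumed format
max_concurrency_str = f"-concurrency{max_concurrency}" if max_concurrency is not None else ""
file_name = f"{backend}-{request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
print(file_name)  # e.g. openai-chat-infqps-concurrency8-model-20250719-231927.json
# --- end example ------------------------------------------------------------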
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
description="Benchmark the online serving throughput.")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
@@ -1011,18 +982,29 @@ if __name__ == "__main__":
"--dataset-name", "--dataset-name",
type=str, type=str,
default="sharegpt", default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"], choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
) )
parser.add_argument("--dataset-path", parser.add_argument(
"--dataset-path",
type=str, type=str,
default=None, default=None,
help="Path to the sharegpt/sonnet dataset. " help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
"Or the huggingface dataset ID if using HF dataset.") )
parser.add_argument("--hyperparameter-path", parser.add_argument(
"--hyperparameter-path",
type=str, type=str,
default=None, default=None,
help="Path to the hyperparameter. ") help="Path to the hyperparameter. ",
)
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@@ -1034,7 +1016,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed " "initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the " "to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.") "if the server is not processing requests fast enough to keep up.",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
@@ -1045,7 +1028,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
type=str, type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 help="Name or path of the tokenizer, if not using the default tokenizer.",
) )
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument( parser.add_argument(
@@ -1058,11 +1041,13 @@ if __name__ == "__main__":
"--logprobs", "--logprobs",
type=int, type=int,
default=None, default=None,
help=("Number of logprobs-per-token to compute & return as part of " help=(
"Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search " "the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy " "is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search " "logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"), "is enabled 1 logprob per token is computed"
),
) )
parser.add_argument( parser.add_argument(
"--request-rate", "--request-rate",
@@ -1099,8 +1084,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
help="Use Torch Profiler. The endpoint must be launched with " help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
) )
parser.add_argument( parser.add_argument(
"--save-result", "--save-result",
@@ -1141,35 +1125,38 @@ if __name__ == "__main__":
"--ignore-eos", "--ignore-eos",
action="store_true", action="store_true",
help="Set ignore_eos flag when sending the benchmark request." help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.") "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument( parser.add_argument(
"--percentile-metrics", "--percentile-metrics",
type=str, type=str,
default="ttft,tpot,itl", default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. " help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. " "This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
"Default value is \"ttft,tpot,itl\".") 'Default value is "ttft,tpot,itl".',
)
parser.add_argument( parser.add_argument(
"--metric-percentiles", "--metric-percentiles",
type=str, type=str,
default="99", default="99",
help="Comma-separated list of percentiles for selected metrics. " help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
"Default value is \"99\". " 'Default value is "99". '
"Use \"--percentile-metrics\" to select metrics.", 'Use "--percentile-metrics" to select metrics.',
) )
parser.add_argument( parser.add_argument(
"--goodput", "--goodput",
nargs="+", nargs="+",
required=False, required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" " help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in " "pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are " "separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments # group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group = parser.add_argument_group("sonnet dataset options")
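
# --- illustrative example (not part of the diff) ---------------------------
# Hedged illustration of how the --goodput flag defined above is consumed:
# nargs="+" collects space-separated "KEY:VALUE" pairs, which are later turned
# into float SLOs. A throwaway parser is used here instead of the script's
# FlexibleArgumentParser.
import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--goodput", nargs="+", required=False)
demo_args = demo_parser.parse_args(["--goodput", "ttft:500", "tpot:50"])
print(demo_args.goodput)  # ['ttft:500', 'tpot:50']
# --- end example ------------------------------------------------------------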
@@ -1197,8 +1184,8 @@ if __name__ == "__main__":
"--sharegpt-output-len", "--sharegpt-output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output length " help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
"from the ShareGPT dataset.") )
random_group = parser.add_argument_group("random dataset options") random_group = parser.add_argument_group("random dataset options")
random_group.add_argument( random_group.add_argument(
@@ -1226,29 +1213,24 @@ if __name__ == "__main__":
"--random-prefix-len", "--random-prefix-len",
type=int, type=int,
default=0, default=0,
help=("Number of fixed prefix tokens before the random context " help=(
"Number of fixed prefix tokens before the random context "
"in a request. " "in a request. "
"The total input length is the sum of `random-prefix-len` and " "The total input length is the sum of `random-prefix-len` and "
"a random " "a random "
"context length sampled from [input_len * (1 - range_ratio), " "context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."), "input_len * (1 + range_ratio)]."
),
) )
hf_group = parser.add_argument_group("hf dataset options") hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset", hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
type=str, hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument( hf_group.add_argument(
"--hf-output-len", "--hf-output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output lengths " help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
"from the sampled HF dataset.",
) )
sampling_group = parser.add_argument_group("sampling parameters") sampling_group = parser.add_argument_group("sampling parameters")
@@ -1256,54 +1238,59 @@ if __name__ == "__main__":
"--top-p", "--top-p",
type=float, type=float,
default=None, default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible " help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--top-k", "--top-k",
type=int, type=int,
default=None, default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible " help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--min-p", "--min-p",
type=float, type=float,
default=None, default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible " help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--temperature", "--temperature",
type=float, type=float,
default=None, default=None,
help="Temperature sampling parameter. Only has effect on " help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy " "openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).") "decoding (i.e. temperature==0.0).",
)
parser.add_argument( parser.add_argument(
'--tokenizer-mode', "--tokenizer-mode",
type=str, type=str,
default="auto", default="auto",
choices=['auto', 'slow', 'mistral', 'custom'], choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the ' help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will ' 'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* ' "always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*' '"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.') '"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name", parser.add_argument(
"--served-model-name",
type=str, type=str,
default=None, default=None,
help="The model name used in the API. " help="The model name used in the API. "
"If not specified, the model name will be the " "If not specified, the model name will be the "
"same as the ``--model`` argument. ") "same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules", parser.add_argument(
nargs='+', "--lora-modules",
nargs="+",
default=None, default=None,
help="A subset of LoRA module names passed in when " help="A subset of LoRA module names passed in when "
"launching the server. For each request, the " "launching the server. For each request, the "
"script chooses a LoRA module at random.") "script chooses a LoRA module at random.",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
View File
@@ -24,9 +24,11 @@ import os
from typing import Any from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace, def convert_to_pytorch_benchmark_format(
args: argparse.Namespace,
metrics: dict[str, list], metrics: dict[str, list],
extra_info: dict[str, Any]) -> list: extra_info: dict[str, Any],
) -> list:
""" """
Save the benchmark results in the format used by PyTorch OSS benchmark with Save the benchmark results in the format used by PyTorch OSS benchmark with
one metric per record one metric per record
@@ -54,12 +56,10 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
}, },
} }
tp = record["benchmark"]["extra_info"]["args"].get( tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
"tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata # Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info: if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
records.append(record) records.append(record)
@@ -68,6 +68,7 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder): class InfEncoder(json.JSONEncoder):
"""InfEncoder""" """InfEncoder"""
def clear_inf(self, o: Any): def clear_inf(self, o: Any):
"""clear_inf""" """clear_inf"""
if isinstance(o, dict): if isinstance(o, dict):
@@ -87,4 +88,3 @@ def write_to_json(filename: str, records: list) -> None:
"""write_to_json""" """write_to_json"""
with open(filename, "w") as f: with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder) json.dump(records, f, cls=InfEncoder)
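
# --- illustrative example (not part of the diff) ---------------------------
# Minimal sketch of what an InfEncoder-style JSONEncoder typically does; the
# real clear_inf body is elided in this hunk, so the "inf" placeholder below
# is an assumption: walk dicts/lists and replace float("inf") with something
# JSON-safe before encoding.
import json
import math

class InfEncoderSketch(json.JSONEncoder):
    def _clear_inf(self, o):
        if isinstance(o, dict):
            return {k: self._clear_inf(v) for k, v in o.items()}
        if isinstance(o, list):
            return [self._clear_inf(v) for v in o]
        if isinstance(o, float) and math.isinf(o):
            return "inf"
        return o

    def iterencode(self, o, _one_shot=False):
        return super().iterencode(self._clear_inf(o), _one_shot)

print(json.dumps({"request_rate": float("inf")}, cls=InfEncoderSketch))
# {"request_rate": "inf"}
# --- end example ------------------------------------------------------------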
View File
@@ -25,32 +25,32 @@ import os
import random import random
import time import time
import warnings import warnings
import yaml from argparse import ArgumentParser as FlexibleArgumentParser
import requests
import copy
from collections.abc import AsyncGenerator, Iterable from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, import requests
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, import yaml
RequestFuncOutput) from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@dataclass @dataclass
class BenchmarkMetrics: class BenchmarkMetrics:
"""Class containing all metrics that are used in this script""" """Class containing all metrics that are used in this script"""
completed: int completed: int
total_input: int total_input: int
total_output: int total_output: int
@@ -133,8 +133,7 @@ async def get_request(
input_requests: Iterable[SampleRequest] = iter(input_requests) input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate. # Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
f"A positive burstiness factor is expected, but given {burstiness}.")
theta = 1.0 / (request_rate * burstiness) theta = 1.0 / (request_rate * burstiness)
for request in input_requests: for request in input_requests:
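
# --- illustrative example (not part of the diff) ---------------------------
# Self-contained sketch of the arrival process implied by theta above,
# assuming the elided loop body samples np.random.gamma(shape=burstiness,
# scale=theta) for each inter-request gap: the mean interval stays
# 1 / request_rate, while burstiness < 1 makes arrivals burstier and
# burstiness == 1 reduces to a Poisson process.
import numpy as np

request_rate, burstiness = 10.0, 0.5
theta = 1.0 / (request_rate * burstiness)
intervals = np.random.gamma(shape=burstiness, scale=theta, size=10_000)
print(round(float(intervals.mean()), 3))  # ~0.1 s, i.e. ~10 requests/s on average
# --- end example ------------------------------------------------------------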
@@ -210,8 +209,9 @@ def calculate_metrics(
s_e2els.append(outputs[i].arrival_time[-1]) s_e2els.append(outputs[i].arrival_time[-1])
# Decode speed excludes the first token # Decode speed excludes the first token
if len(outputs[i].arrival_time) > 2: if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) / s_decodes.append(
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])) (outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
)
completed += 1 completed += 1
else: else:
actual_output_lens.append(0) actual_output_lens.append(0)
@@ -224,16 +224,13 @@ def calculate_metrics(
if "ttft" in goodput_config_dict: if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts) valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict: if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots) valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict: if "e2el" in goodput_config_dict:
valid_metrics.append(e2els) valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics): for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
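
# --- illustrative example (not part of the diff) ---------------------------
# Tiny illustration of the goodput check above: zip(*valid_metrics) regroups
# the per-metric lists into per-request tuples, and a request counts as "good"
# only if every selected metric meets its SLO. Values (in seconds) are made up.
slo_values = [0.5, 2.0]                    # SLOs for (ttft, e2el)
valid_metrics = [[0.2, 0.9], [1.5, 2.5]]   # per-metric lists: ttfts, e2els
good_completed = sum(all(s >= r for s, r in zip(slo_values, req)) for req in zip(*valid_metrics))
print(good_completed)  # 1 -> only the first request meets both SLOs
# --- end example ------------------------------------------------------------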
@@ -242,9 +239,9 @@ def calculate_metrics(
if completed == 0: if completed == 0:
warnings.warn( warnings.warn(
"All requests failed. This is likely due to a misconfiguration " "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
"on the benchmark arguments.", stacklevel=2,
stacklevel=2) )
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
@@ -253,64 +250,50 @@ def calculate_metrics(
request_goodput=good_completed / dur_s, request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s, output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_s_decode=np.mean(s_decodes or 0) * mean_s_decode=np.mean(s_decodes or 0) * 1, # ttfts is empty if streaming is not supported by backend
1, # ttfts is empty if streaming is not supported by backend
std_s_decode=np.std(s_decodes or 0) * 1, std_s_decode=np.std(s_decodes or 0) * 1,
median_s_decode=np.median(s_decodes or 0) * 1, median_s_decode=np.median(s_decodes or 0) * 1,
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
for p in selected_percentiles], mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles], mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
mean_s_ttft_ms=np.mean(s_ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
std_s_ttft_ms=np.std(s_ttfts or 0) * 1000, std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
median_s_ttft_ms=np.median(s_ttfts or 0) * 1000, median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_s_itl_ms=np.mean(s_itls or 0) * 1000, mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
std_s_itl_ms=np.std(s_itls or 0) * 1000, std_s_itl_ms=np.std(s_itls or 0) * 1000,
median_s_itl_ms=np.median(s_itls or 0) * 1000, median_s_itl_ms=np.median(s_itls or 0) * 1000,
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_e2el_ms=np.mean(e2els or 0) * 1000, mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000, mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
std_s_e2el_ms=np.std(s_e2els or 0) * 1000, std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
median_s_e2el_ms=np.median(s_e2els or 0) * 1000, median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_input_len=np.mean(input_lens or 0) * 1, mean_input_len=np.mean(input_lens or 0) * 1,
std_input_len=np.std(input_lens or 0) * 1, std_input_len=np.std(input_lens or 0) * 1,
median_input_len=np.median(input_lens or 0) * 1, median_input_len=np.median(input_lens or 0) * 1,
percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
for p in selected_percentiles],
mean_s_input_len=np.mean(infer_input_lens or 0) * 1, mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
std_s_input_len=np.std(infer_input_lens or 0) * 1, std_s_input_len=np.std(infer_input_lens or 0) * 1,
median_s_input_len=np.median(infer_input_lens or 0) * 1, median_s_input_len=np.median(infer_input_lens or 0) * 1,
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
for p in selected_percentiles],
mean_output_len=np.mean(actual_output_lens or 0) * 1, mean_output_len=np.mean(actual_output_lens or 0) * 1,
std_output_len=np.std(actual_output_lens or 0) * 1, std_output_len=np.std(actual_output_lens or 0) * 1,
median_output_len=np.median(actual_output_lens or 0) * 1, median_output_len=np.median(actual_output_lens or 0) * 1,
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
for p in selected_percentiles],
) )
return metrics, actual_output_lens return metrics, actual_output_lens
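
# --- illustrative example (not part of the diff) ---------------------------
# Stand-alone example of the percentile bookkeeping above: every latency list
# is reduced to (percentile, value-in-ms) tuples for the selected percentiles.
# The ttfts values are made up.
import numpy as np

ttfts = [0.08, 0.10, 0.12, 0.30]  # seconds
selected_percentiles = [50.0, 99.0]
percentiles_ttft_ms = [(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles]
print(percentiles_ttft_ms)  # [(50.0, 110.0), (99.0, ~294.6)]
# --- end example ------------------------------------------------------------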
@@ -351,20 +334,22 @@ async def benchmark(
if lora_modules: if lora_modules:
# For each input request, choose a LoRA module at random. # For each input request, choose a LoRA module at random.
lora_modules = iter( lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])
[random.choice(lora_modules) \
for _ in range(len(input_requests))])
if profile: if profile:
print("Starting profiler...") print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id, test_prompt = None
test_output_len = None
profile_input = RequestFuncInput(
model=model_id,
model_name=model_name, model_name=model_name,
prompt=test_prompt, prompt=test_prompt,
api_url=base_url + "/start_profile", api_url=base_url + "/start_profile",
output_len=test_output_len, output_len=test_output_len,
logprobs=logprobs, logprobs=logprobs,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body) extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
@@ -384,16 +369,13 @@ async def benchmark(
# and it will simplify the code in limited_request_func. # and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency) # semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext()) # if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
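
# --- illustrative example (not part of the diff) ---------------------------
# Minimal runnable sketch of the optional-semaphore pattern above: when
# max_concurrency is None no gate is applied, otherwise at most that many
# request coroutines run at once. fake_request stands in for request_func.
import asyncio

async def fake_request(i):
    await asyncio.sleep(0.01)
    return i

async def run(max_concurrency=None):
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited(i):
        if semaphore is None:
            return await fake_request(i)
        async with semaphore:
            return await fake_request(i)

    return await asyncio.gather(*(limited(i) for i in range(8)))

print(asyncio.run(run(max_concurrency=2)))  # [0, 1, 2, 3, 4, 5, 6, 7]
# --- end example ------------------------------------------------------------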
@@ -409,7 +391,8 @@ async def benchmark(
req_lora_module = next(lora_modules) req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id, request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name, model_name=req_model_name,
prompt=prompt, prompt=prompt,
prompt_len=0, prompt_len=0,
@@ -419,15 +402,15 @@ async def benchmark(
output_len=output_len, output_len=output_len,
logprobs=logprobs, logprobs=logprobs,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=extra_body) extra_body=extra_body,
tasks.append( )
asyncio.create_task( tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
limited_request_func(request_func_input=request_func_input,
pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
print(f"完成时间:{datetime.now()}") print(f"完成时间:{datetime.now()}")
if profile: if profile:
print("Stopping profiler...") print("Stopping profiler...")
test_output_len = None
test_output_len = None
profile_input = RequestFuncInput( profile_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=test_prompt, prompt=test_prompt,
@@ -454,22 +437,16 @@ async def benchmark(
) )
print("Benchmark complete!!!") print("Benchmark complete!!!")
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
metrics.total_output)) print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict: if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = { result = {
"duration": benchmark_duration, "duration": benchmark_duration,
@@ -477,8 +454,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input, "total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output, "total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput, "request_throughput": metrics.request_throughput,
"request_goodput:": "request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput, "output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput, "total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
@@ -491,7 +467,6 @@ async def benchmark(
"reasoning_contents": [output.reasoning_content for output in outputs], "reasoning_contents": [output.reasoning_content for output in outputs],
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
} }
quick_result = copy.deepcopy(result)
def process_one_metric( def process_one_metric(
# E.g., "ttft" # E.g., "ttft"
@@ -505,24 +480,25 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):", f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms"))) getattr(metrics, f"mean_{metric_attribute_name}_ms"),
print("{:<40} {:<10.2f}".format( )
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):", f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"))) getattr(metrics, f"median_{metric_attribute_name}_ms"),
result[f"mean_{metric_attribute_name}_ms"] = getattr( )
metrics, f"mean_{metric_attribute_name}_ms") )
result[f"median_{metric_attribute_name}_ms"] = getattr( result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
metrics, f"median_{metric_attribute_name}_ms") result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr( result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
metrics, f"std_{metric_attribute_name}_ms") for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length( def process_one_length(
@@ -537,31 +513,31 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name}:", f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}"))) getattr(metrics, f"mean_{metric_attribute_name}"),
print("{:<40} {:<10.2f}".format( )
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name}:", f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}"))) getattr(metrics, f"median_{metric_attribute_name}"),
result[f"mean_{metric_attribute_name}"] = getattr( )
metrics, f"mean_{metric_attribute_name}") )
result[f"median_{metric_attribute_name}"] = getattr( result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
metrics, f"median_{metric_attribute_name}") result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr( result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
metrics, f"std_{metric_attribute_name}") for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
value))
result[f"p{p_word}_{metric_attribute_name}"] = value result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)") process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token") process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT", process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency") process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -581,6 +557,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
""" """
Quick evaluation Quick evaluation
""" """
def process_quick_metric( def process_quick_metric(
metric_attribute_name: str, metric_attribute_name: str,
metric_name: str, metric_name: str,
@@ -588,7 +565,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
): ):
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms") mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value)) print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value))
quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value
@@ -600,17 +577,17 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
): ):
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}") mean_value = getattr(metrics, f"mean_{metric_attribute_name}")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value)) print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value))
quick_result[f"mean_{metric_attribute_name}"] = mean_value quick_result[f"mean_{metric_attribute_name}"] = mean_value
print("\n\n\n") print("\n\n\n")
print("{s:{c}^{n}}".format(s=' Benchmark Quick Summary ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Benchmark Quick Summary ", n=50, c="="))
process_quick_length("s_decode", "Decode", "解码速度(tok/s)") process_quick_length("s_decode", "Decode", "解码速度(tok/s)")
process_quick_metric("ttft", "TTFT", "Time to First Token") process_quick_metric("ttft", "TTFT", "Time to First Token")
process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token") process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_quick_metric("tpot", "TPOT", process_quick_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_quick_metric("itl", "ITL", "Inter-token Latency") process_quick_metric("itl", "ITL", "Inter-token Latency")
process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency") process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_quick_metric("e2el", "E2EL", "End-to-end Latency") process_quick_metric("e2el", "E2EL", "End-to-end Latency")
@@ -633,12 +610,14 @@ def check_goodput_args(args):
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of " "The service level objective name should be one of "
f"{str(VALID_NAMES)}. ") f"{VALID_NAMES!s}. "
)
if slo_val < 0: if slo_val < 0:
raise ValueError( raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative."
)
return goodput_config_dict return goodput_config_dict
@@ -652,37 +631,43 @@ def parse_goodput(slo_pairs):
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds."
) from err
return goodput_config_dict return goodput_config_dict
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
results: dict[str, Any],
file_name: str) -> None:
"""Save the benchmarking results to PyTorch Benchmark Format JSON file""" """Save the benchmarking results to PyTorch Benchmark Format JSON file"""
metrics = [ metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", "median_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", "mean_ttft_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" "std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
] ]
# These raw data might be useful, but they are rather big. They can be added # These raw data might be useful, but they are rather big. They can be added
# later if needed # later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={k: [results[k]] metrics={k: [results[k]] for k in metrics},
for k in metrics}, extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
extra_info={ )
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
if pt_records: if pt_records:
# Don't use json suffix here as we don't want CI to pick it up # Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
write_to_json(pt_file, pt_records) write_to_json(pt_file, pt_records)
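The PyTorch-benchmark export above splits a flat results dict into tracked metric series and auxiliary info; convert_to_pytorch_benchmark_format and write_to_json are project helpers not shown in this diff, so the sketch below only illustrates the dict handling (the sample data is made up):

# Sketch: separate per-metric series from extra info, dropping oversized raw fields.
results = {"mean_ttft_ms": 12.3, "p99_ttft_ms": 40.1, "model_id": "demo", "ttfts": [1, 2, 3]}
metrics = ["mean_ttft_ms", "p99_ttft_ms"]
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]

metric_series = {k: [results[k]] for k in metrics}  # each metric wrapped in a list, one entry per run
extra_info = {k: v for k, v in results.items() if k not in metrics and k not in ignored_metrics}
print(metric_series)  # {'mean_ttft_ms': [12.3], 'p99_ttft_ms': [40.1]}
print(extra_info)     # {'model_id': 'demo'}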
def check_health(api_base_url: str) -> bool: def check_health(api_base_url: str) -> bool:
health_url = api_base_url.rstrip("/") + "/health" health_url = api_base_url.rstrip("/") + "/health"
try: try:
@@ -697,6 +682,7 @@ def check_health(api_base_url: str) -> bool:
print(f"[HEALTH] Failed to connect to {health_url}: {e}") print(f"[HEALTH] Failed to connect to {health_url}: {e}")
return False return False
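A standard-library sketch of the /health probe pattern used above; the script's actual HTTP client is not reproduced here, and urllib plus a 5-second timeout are assumptions made purely for illustration:

# Sketch: probe an OpenAI-compatible server's /health endpoint before benchmarking.
import urllib.request

def probe_health(api_base_url: str) -> bool:
    health_url = api_base_url.rstrip("/") + "/health"
    try:
        with urllib.request.urlopen(health_url, timeout=5) as resp:
            return resp.status == 200
    except OSError as e:  # urllib.error.URLError is an OSError subclass
        print(f"[HEALTH] Failed to connect to {health_url}: {e}")
        return False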
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
"""Main entry point""" """Main entry point"""
print(args) print(args)
@@ -707,7 +693,6 @@ def main(args: argparse.Namespace):
model_id = args.model model_id = args.model
model_name = args.served_model_name model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None: if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}" api_url = f"{args.base_url}{args.endpoint}"
@@ -717,21 +702,15 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}" base_url = f"http://{args.host}:{args.port}"
if args.dataset_name is None: if args.dataset_name is None:
raise ValueError( raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
# For datasets that follow a similar structure, use a mapping. # For datasets that follow a similar structure, use a mapping.
dataset_mapping = { dataset_mapping = {
"EB": "EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
lambda: EBDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
), ),
"EBChat": "EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
lambda: EBChatDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts, num_requests=args.num_prompts,
output_len=args.sharegpt_output_len, output_len=args.sharegpt_output_len,
), ),
@@ -751,15 +730,14 @@ def main(args: argparse.Namespace):
"top_p": args.top_p, "top_p": args.top_p,
"top_k": args.top_k, "top_k": args.top_k,
"min_p": args.min_p, "min_p": args.min_p,
"temperature": args.temperature "temperature": args.temperature,
}.items() if v is not None }.items()
if v is not None
} }
# Sampling parameters are only supported by openai-compatible backend. # Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError( raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")
"Sampling parameters are only supported by openai-compatible "
"backends.")
if "temperature" not in sampling_params: if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding. sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -790,15 +768,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_body=sampling_params, extra_body=sampling_params,
)) )
)
# Save config and results to json # Save config and results to json
if args.save_result: if args.save_result:
@@ -819,22 +796,23 @@ def main(args: argparse.Namespace):
kvstring = item.split("=") kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip() result_json[kvstring[0].strip()] = kvstring[1].strip()
else: else:
raise ValueError( raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
"Invalid metadata format. Please use KEY=VALUE format."
)
if not args.save_detailed: if not args.save_detailed:
# Remove fields with too many data points # Remove fields with too many data points
for field in [ for field in [
"input_lens", "output_lens", "ttfts", "itls", "input_lens",
"generated_texts", "errors" "output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]: ]:
if field in result_json: if field in result_json:
del result_json[field] del result_json[field]
# Traffic # Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
@@ -843,21 +821,19 @@ def main(args: argparse.Namespace):
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}" max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
if args.max_concurrency is not None else "") file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
if args.result_dir: if args.result_dir:
file_name = os.path.join(args.result_dir, file_name) file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile: with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(result_json, outfile) json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name) save_to_pytorch_benchmark_format(args, result_json, file_name)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
description="Benchmark the online serving throughput.")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
@@ -883,18 +859,29 @@ if __name__ == "__main__":
"--dataset-name", "--dataset-name",
type=str, type=str,
default="sharegpt", default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"], choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
) )
parser.add_argument("--dataset-path", parser.add_argument(
"--dataset-path",
type=str, type=str,
default=None, default=None,
help="Path to the sharegpt/sonnet dataset. " help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
"Or the huggingface dataset ID if using HF dataset.") )
parser.add_argument("--hyperparameter-path", parser.add_argument(
"--hyperparameter-path",
type=str, type=str,
default=None, default=None,
help="Path to the hyperparameter. ") help="Path to the hyperparameter. ",
)
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@@ -906,7 +893,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed " "initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the " "to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.") "if the server is not processing requests fast enough to keep up.",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
@@ -917,7 +905,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
type=str, type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 help="Name or path of the tokenizer, if not using the default tokenizer.",
) )
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument( parser.add_argument(
@@ -930,11 +918,13 @@ if __name__ == "__main__":
"--logprobs", "--logprobs",
type=int, type=int,
default=None, default=None,
help=("Number of logprobs-per-token to compute & return as part of " help=(
"Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search " "the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy " "is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search " "logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"), "is enabled 1 logprob per token is computed"
),
) )
parser.add_argument( parser.add_argument(
"--request-rate", "--request-rate",
@@ -971,8 +961,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
help="Use Torch Profiler. The endpoint must be launched with " help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
) )
parser.add_argument( parser.add_argument(
"--save-result", "--save-result",
@@ -1013,35 +1002,38 @@ if __name__ == "__main__":
"--ignore-eos", "--ignore-eos",
action="store_true", action="store_true",
help="Set ignore_eos flag when sending the benchmark request." help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.") "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument( parser.add_argument(
"--percentile-metrics", "--percentile-metrics",
type=str, type=str,
default="ttft,tpot,itl", default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. " help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. " "This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
"Default value is \"ttft,tpot,itl\".") 'Default value is "ttft,tpot,itl".',
)
parser.add_argument( parser.add_argument(
"--metric-percentiles", "--metric-percentiles",
type=str, type=str,
default="99", default="99",
help="Comma-separated list of percentiles for selected metrics. " help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
"Default value is \"99\". " 'Default value is "99". '
"Use \"--percentile-metrics\" to select metrics.", 'Use "--percentile-metrics" to select metrics.',
) )
parser.add_argument( parser.add_argument(
"--goodput", "--goodput",
nargs="+", nargs="+",
required=False, required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" " help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in " "pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are " "separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments # group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -1069,8 +1061,8 @@ if __name__ == "__main__":
"--sharegpt-output-len", "--sharegpt-output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output length " help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
"from the ShareGPT dataset.") )
random_group = parser.add_argument_group("random dataset options") random_group = parser.add_argument_group("random dataset options")
random_group.add_argument( random_group.add_argument(
@@ -1098,29 +1090,24 @@ if __name__ == "__main__":
"--random-prefix-len", "--random-prefix-len",
type=int, type=int,
default=0, default=0,
help=("Number of fixed prefix tokens before the random context " help=(
"Number of fixed prefix tokens before the random context "
"in a request. " "in a request. "
"The total input length is the sum of `random-prefix-len` and " "The total input length is the sum of `random-prefix-len` and "
"a random " "a random "
"context length sampled from [input_len * (1 - range_ratio), " "context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."), "input_len * (1 + range_ratio)]."
),
) )
hf_group = parser.add_argument_group("hf dataset options") hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset", hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
type=str, hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument( hf_group.add_argument(
"--hf-output-len", "--hf-output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output lengths " help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
"from the sampled HF dataset.",
) )
sampling_group = parser.add_argument_group("sampling parameters") sampling_group = parser.add_argument_group("sampling parameters")
@@ -1128,52 +1115,58 @@ if __name__ == "__main__":
"--top-p", "--top-p",
type=float, type=float,
default=None, default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible " help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--top-k", "--top-k",
type=int, type=int,
default=None, default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible " help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--min-p", "--min-p",
type=float, type=float,
default=None, default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible " help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--temperature", "--temperature",
type=float, type=float,
default=None, default=None,
help="Temperature sampling parameter. Only has effect on " help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy " "openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).") "decoding (i.e. temperature==0.0).",
)
parser.add_argument( parser.add_argument(
'--tokenizer-mode', "--tokenizer-mode",
type=str, type=str,
default="auto", default="auto",
choices=['auto', 'slow', 'mistral', 'custom'], choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the ' help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will ' 'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* ' "always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*' '"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.') '"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name", parser.add_argument(
"--served-model-name",
type=str, type=str,
default=None, default=None,
help="The model name used in the API. " help="The model name used in the API. "
"If not specified, the model name will be the " "If not specified, the model name will be the "
"same as the ``--model`` argument. ") "same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules", parser.add_argument(
nargs='+', "--lora-modules",
nargs="+",
default=None, default=None,
help="A subset of LoRA module names passed in when " help="A subset of LoRA module names passed in when "
"launching the server. For each request, the " "launching the server. For each request, the "
"script chooses a LoRA module at random.") "script chooses a LoRA module at random.",
)
args = parser.parse_args() args = parser.parse_args()
View File
@@ -57,5 +57,3 @@ paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
num_experts); num_experts);
return token_nums_per_expert; return token_nums_per_expert;
} }
View File
@@ -14,9 +14,10 @@
"""read_ids""" """read_ids"""
import os import os
import numpy as np
import struct import struct
import numpy as np
def deserialize_from_file(fp): def deserialize_from_file(fp):
"""deserialize from file""" """deserialize from file"""
View File
@@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
"""read temp_ids from file""" """read temp_ids from file"""
import os import os
import numpy as np
import struct import struct
import numpy as np
def deserialize_from_file(fp): def deserialize_from_file(fp):
""" """
View File
@@ -118,4 +118,3 @@ void CUDART_CB RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_que
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal(); RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
// std::printf("#### save_cache_kv_complete_signal_layerwise_per_query); // std::printf("#### save_cache_kv_complete_signal_layerwise_per_query);
} }
View File
@@ -41,8 +41,7 @@ ROOT_DIR = Path(__file__).parent.parent
# cannot import envs directly because it depends on fastdeploy, # cannot import envs directly because it depends on fastdeploy,
# which is not installed yet # which is not installed yet
envs = load_module_from_path('envs', envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
archs = json.loads(envs.FD_BUILDING_ARCS) archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True" use_bf16 = envs.FD_CPU_USE_BF16 == "True"
@@ -143,8 +142,7 @@ def get_nvcc_version():
""" """
Get cuda version of nvcc. Get cuda version of nvcc.
""" """
nvcc_output = subprocess.check_output(["nvcc", "--version"], nvcc_output = subprocess.check_output(["nvcc", "--version"], universal_newlines=True)
universal_newlines=True)
output = nvcc_output.split() output = nvcc_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
nvcc_cuda_version = float(output[release_idx].split(",")[0]) nvcc_cuda_version = float(output[release_idx].split(",")[0])
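A worked example of the version parsing in get_nvcc_version above, run on a captured output fragment so it is self-contained (the sample string is illustrative of typical `nvcc --version` output):

# Sketch: pull the CUDA release number out of `nvcc --version` text.
sample = "Cuda compilation tools, release 12.4, V12.4.131"
tokens = sample.split()
release_idx = tokens.index("release") + 1
print(float(tokens[release_idx].split(",")[0]))  # 12.4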
@@ -160,13 +158,19 @@ def get_gencode_flags(archs):
for cc_val in cc_s: for cc_val in cc_s:
if cc_val == 90: if cc_val == 90:
arch_code = "90a" arch_code = "90a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"] flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
# Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a' # Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
# https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/ # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
# "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0" # "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
arch_code = "100a" arch_code = "100a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"] flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
else: else:
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"] flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
return flags return flags
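A short sketch of the gencode expansion above applied to a couple of example compute capabilities; the 90 -> "90a" and 100 -> "100a" special cases follow the reformatted code, while the input list is illustrative:

# Sketch: expand compute capabilities into nvcc -gencode flags.
def gencode_flags(cc_values):
    flags = []
    for cc_val in cc_values:
        if cc_val == 90:
            arch = "90a"    # SM90: CUTLASS needs the 'a' suffix
        elif cc_val == 100:
            arch = "100a"   # Blackwell SM10.x: family-specific 'a' suffix
        else:
            arch = str(cc_val)
        flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
    return flags

print(gencode_flags([80, 90]))
# ['-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_90a,code=sm_90a']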
@@ -302,8 +306,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir): if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
if not os.path.exists(cutlass_dir): if not os.path.exists(cutlass_dir):
os.makedirs(cutlass_dir) os.makedirs(cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
cutlass_dir)
if not os.listdir(cutlass_dir): if not os.listdir(cutlass_dir):
raise ValueError("Git clone cutlass failed!") raise ValueError("Git clone cutlass failed!")
@@ -312,8 +315,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir): if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
if not os.path.exists(deep_gemm_dir): if not os.path.exists(deep_gemm_dir):
os.makedirs(deep_gemm_dir) os.makedirs(deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
deep_gemm_dir)
if not os.listdir(deep_gemm_dir): if not os.listdir(deep_gemm_dir):
raise ValueError("Git clone DeepGEMM failed!") raise ValueError("Git clone DeepGEMM failed!")
cur_path = os.path.dirname(os.path.abspath(__file__)) cur_path = os.path.dirname(os.path.abspath(__file__))
@@ -347,15 +349,13 @@ elif paddle.is_compiled_with_cuda():
try: try:
shutil.copytree(src_dir, dst_dir) shutil.copytree(src_dir, dst_dir)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
f"Failed to copy from {src_dir} to {dst_dir}: {e}")
json_dir = "third_party/nlohmann_json" json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir): if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir): if not os.path.exists(json_dir):
os.makedirs(json_dir) os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
json_dir)
if not os.listdir(json_dir): if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!") raise ValueError("Git clone nlohmann_json failed!")
@@ -372,7 +372,7 @@ elif paddle.is_compiled_with_cuda():
"-Ithird_party/nlohmann_json/include", "-Ithird_party/nlohmann_json/include",
] ]
nvcc_version = get_nvcc_version() nvcc_version = get_nvcc_version()
print(f'nvcc_version = {nvcc_version}') print(f"nvcc_version = {nvcc_version}")
if nvcc_version >= 12.0: if nvcc_version >= 12.0:
sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"] sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
cc = max(get_sm_version(archs)) cc = max(get_sm_version(archs))
@@ -414,9 +414,7 @@ elif paddle.is_compiled_with_cuda():
# Running generate fp8 gemm codes. # Running generate fp8 gemm codes.
# Common for SM89, SM90, SM100 (Blackwell) # Common for SM89, SM90, SM100 (Blackwell)
nvcc_compile_args += ["-DENABLE_FP8"] nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += [ nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
]
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS. # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
@@ -431,14 +429,9 @@ elif paddle.is_compiled_with_cuda():
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a "-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
] ]
print("SM90: Running SM90-specific FP8 kernel auto-generation.") print("SM90: Running SM90-specific FP8 kernel auto-generation.")
os.system( os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
os.system( os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
nvcc_compile_args += [ nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1", "-DENABLE_SCALED_MM_SM90=1",
@@ -473,14 +466,12 @@ elif paddle.is_compiled_with_cuda():
else: # For cc >= 89 but not 90 or 100 (e.g. SM89) else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
print(f"SM{cc}: Running generic FP8 kernel auto-generation.") print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system( os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else: # For cc == 89 (Ada) else: # For cc == 89 (Ada)
print("SM89: Running generic FP8 kernel auto-generation.") print("SM89: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system( os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
# Common FP8 sources for SM89+ # Common FP8 sources for SM89+
sources += [ sources += [
@@ -493,7 +484,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu", "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu", "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu", "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
"gpu_ops/fused_hadamard_quant_fp8.cu" "gpu_ops/fused_hadamard_quant_fp8.cu",
] ]
sources += find_end_files(fp8_auto_gen_directory, ".cu") sources += find_end_files(fp8_auto_gen_directory, ".cu")
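find_end_files is another helper whose body is not part of this diff; a plausible, clearly hypothetical sketch of collecting the auto-generated .cu sources by suffix:

# Sketch (assumption): gather every file under a directory tree that ends with a given suffix.
import os

def find_end_files(directory: str, end_str: str) -> list:
    matches = []
    for root, _, names in os.walk(directory):
        matches.extend(os.path.join(root, name) for name in names if name.endswith(end_str))
    return matches

# e.g. sources += find_end_files("gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", ".cu")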
View File
@@ -27,7 +27,8 @@ setup(
"cpu_ops/rebuild_padding.cc", "cpu_ops/rebuild_padding.cc",
], ],
extra_compile_args=[ extra_compile_args=[
"-DPy_LIMITED_API=0x03090000", "-DPADDLE_ON_INFERENCE" "-DPy_LIMITED_API=0x03090000",
"-DPADDLE_ON_INFERENCE",
], ],
), ),
) )
View File
@@ -26,8 +26,7 @@ ROOT_DIR = Path(__file__).parent.parent
# which is not installed yet # which is not installed yet
from .setup_ops import load_module_from_path from .setup_ops import load_module_from_path
envs = load_module_from_path('envs', envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
BUILDING_ARCS = [] BUILDING_ARCS = []
use_bf16 = envs.FD_CPU_USE_BF16 == "True" use_bf16 = envs.FD_CPU_USE_BF16 == "True"
View File
@@ -48,8 +48,7 @@ def get_candidate_configs(sm):
candidate_configs = list() candidate_configs = list()
hasbias = ("false", "true") hasbias = ("false", "true")
KernelSchedule = ( KernelSchedule = ("KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>",)
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", )
EpilogueSchedule = ("TmaWarpSpecializedCooperative",) EpilogueSchedule = ("TmaWarpSpecializedCooperative",)
TileSchedule = ("PersistentScheduler", "StreamKScheduler") TileSchedule = ("PersistentScheduler", "StreamKScheduler")
for act_tag in [ for act_tag in [
@@ -57,8 +56,18 @@ def get_candidate_configs(sm):
# ("relu", "ReLu"), # ("relu", "ReLu"),
# ("gelu", "GELU"), # ("gelu", "GELU"),
]: ]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, candidate_configs.extend(
EpilogueSchedule, TileSchedule)]) [
(
hasbias,
act_tag,
tiles,
KernelSchedule,
EpilogueSchedule,
TileSchedule,
)
]
)
return candidate_configs return candidate_configs
@@ -66,16 +75,13 @@ def get_shape_str(tile_shape):
""" """
return tile_shape string. return tile_shape string.
""" """
blocks, clusters = [ blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks] blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters] clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters return blocks, clusters
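A worked example of the tile-shape parsing get_shape_str performs after this cleanup; the shape strings follow the "<_128, _128, _128>" convention used throughout these generators (the concrete values are illustrative):

# Sketch: turn CUTLASS-style shape strings into plain dimension lists.
tile_shape = ("<_128, _128, _128>", "<_1, _2, _1>")  # (block tile, cluster shape)
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
print(blocks, clusters)  # ['128', '128', '128'] ['1', '2', '1']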
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, tile_schedule):
tile_schedule):
""" """
check the cutlass config valid. check the cutlass config valid.
""" """
@@ -304,13 +310,10 @@ def SubstituteTemplate(template, values_base):
SubstituteTemplate SubstituteTemplate
""" """
values = copy.deepcopy(values_base) values = copy.deepcopy(values_base)
if values.get("KernelSchedule" if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
) is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"] values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule" if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
) is not None and "Auto" in values["EpilogueSchedule"]: values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template text = template
changed = True changed = True
while changed: while changed:
@@ -329,8 +332,7 @@ def parse_args():
parse_args parse_args
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -346,14 +348,14 @@ def parse_args():
# generate source .cu # generate source .cu
def generate_source_cu( def generate_source_cu(
inputs_type: (str), inputs_type: str,
outputs_type: (str), outputs_type: str,
hasbiases: (str), hasbiases: str,
act_tag: (str), act_tag: str,
tiles: (str), tiles: str,
KernelSchedule: (str), KernelSchedule: str,
EpilogueSchedule: (str), EpilogueSchedule: str,
TileSchedule: (str), TileSchedule: str,
sm: str, sm: str,
): ):
""" """
@@ -369,8 +371,11 @@ def generate_source_cu(
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid( if not check_config_valid(
tile_config, kernel_schedule, tile_config,
epilogue_schedule, tile_schedule): kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue continue
value_dict = { value_dict = {
"input_type": input_type, "input_type": input_type,
@@ -385,30 +390,32 @@ def generate_source_cu(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmDeclare, value_dict)
GemmDeclare, value_dict)
return all_code return all_code
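The generators above all funnel their value_dict entries through SubstituteTemplate, which keeps filling named placeholders until the text stops changing; a minimal sketch of that idea, assuming ${name}-style placeholders (the repository's actual template syntax is not shown in these hunks):

# Sketch (assumption): iterative placeholder substitution for code generation.
import re

def substitute_template(template: str, values: dict) -> str:
    text = template
    changed = True
    while changed:  # loop so placeholders introduced by substituted values also resolve
        changed = False
        for key, value in values.items():
            new_text = re.sub(r"\$\{%s\}" % key, value, text)
            if new_text != text:
                changed = True
            text = new_text
    return text

print(substitute_template("void gemm_sm${sm}_${gemm_config}();", {"sm": "90", "gemm_config": "block128x128x128"}))
# void gemm_sm90_block128x128x128();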
# generate gemm launch .cu # generate gemm launch .cu
def generate_launch_gemm_cus( def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str), generate_dir: str,
fuse_gemm_configs: tuple, sm: str): inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
generate_launch_gemm_cus generate_launch_gemm_cus
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
TileSchedule: (str) = single_config[5] TileSchedule: str = single_config[5]
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, head_path = os.path.join(generate_dir, f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile_config in tiles: for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config) blocks, clusters = get_shape_str(tile_config)
@@ -418,19 +425,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_config, kernel_schedule, if not check_config_valid(
tile_config,
kernel_schedule,
epilogue_schedule, epilogue_schedule,
tile_schedule): tile_schedule,
):
continue continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}" gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = { value_dict = {
"sm": "sm": sm[-2:],
sm[-2:], "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
} }
head_all_code += SubstituteTemplate( head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True) os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f: with open(head_path, "w") as f:
f.write(head_all_code) f.write(head_all_code)
@@ -444,19 +451,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule, epilogue_schedule,
tile_schedule): tile_schedule,
):
continue continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}" gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = { value_dict = {
"sm": "sm": sm[-2:],
sm[-2:], "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
} }
source_all_code = SubstituteTemplate( source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
LaunchGemmPart0, value_dict)
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
for output_type in outputs_type: for output_type in outputs_type:
@@ -476,16 +483,14 @@ def generate_launch_gemm_cus(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
gemm_config_str = gemm_config_str.replace("<", "").replace( gemm_config_str = gemm_config_str.replace("<", "").replace(">", "")
">", "")
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu" f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
) )
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
@@ -495,19 +500,18 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu # generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str), def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
fuse_gemm_configs: tuple, sm: str):
""" """
generate_dispatch_gemm_cu generate_dispatch_gemm_cu
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
TileSchedule: (str) = single_config[5] TileSchedule: str = single_config[5]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]}) all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
@@ -530,9 +534,12 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule, epilogue_schedule,
tile_schedule): tile_schedule,
):
continue continue
value_dict = { value_dict = {
"TileShape": tile_shape[0], "TileShape": tile_shape[0],
@@ -554,18 +561,18 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule, epilogue_schedule,
tile_schedule): tile_schedule,
):
continue continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}" gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = { value_dict = {
"sm": "sm": sm[-2:],
sm[-2:], "tile_id": str(tile_id),
"tile_id": "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
str(tile_id),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
} }
all_code += SubstituteTemplate(code_part5, value_dict) all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1 tile_id += 1
@@ -610,12 +617,17 @@ if __name__ == "__main__":
f.close() f.close()
# Compile parallelization # Compile parallelization
generate_launch_gemm_cus( generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type, "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
outputs_type, fuse_gemm_configs, sm_dict[sm]) inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag # hard code for act_tag
file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" file_name = (
f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu") f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_gemm_cu( all_code = generate_dispatch_gemm_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,
View File
@@ -24,7 +24,8 @@ def get_candidate_tiles():
""" """
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")] base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([ base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"), ("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"), ("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"), ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
@@ -38,13 +39,13 @@ def get_candidate_tiles():
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"), ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]) ]
)
return base_configs return base_configs
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
max_stages):
""" """
get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel. get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel.
""" """
@@ -299,8 +300,7 @@ def check_min_split_k(value):
""" """
ivalue = int(value) ivalue = int(value)
if ivalue > 1: if ivalue > 1:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.")
"Dual gemm split_k mode is not support.")
return ivalue return ivalue
@@ -310,8 +310,7 @@ def check_max_split_k(value):
""" """
ivalue = int(value) ivalue = int(value)
if ivalue > 1: if ivalue > 1:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support..")
"Dual gemm split_k mode is not support..")
return ivalue return ivalue
@@ -320,8 +319,7 @@ def parse_args():
parse_args parse_args
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -421,8 +419,7 @@ def generate_dual_gemm_source_cu(
"hasbias": hasbias, "hasbias": hasbias,
"SM": sm, "SM": sm,
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
GemmSplitKDeclare, value_dict)
all_code += CommonTail all_code += CommonTail
return all_code return all_code
@@ -449,12 +446,12 @@ def generate_launch_dual_gemm_cus(
head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h") head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile gemm_config = (
] f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_" f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}") f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = { value_dict = {
@@ -467,12 +464,12 @@ def generate_launch_dual_gemm_cus(
f.close() f.close()
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile gemm_config = (
] f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_" f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}") f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = { value_dict = {
@@ -498,16 +495,14 @@ def generate_launch_dual_gemm_cus(
"num_stages": str(stage), "num_stages": str(stage),
"SM": sm, "SM": sm,
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
# split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict) # split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
# source_all_code += split_k_code # source_all_code += split_k_code
# source_all_code += LaunchGemmPart4 # source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
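The per-config source files written above take their names from the block/warp/MMA tile shapes plus the pipeline stage count; a worked example of that naming scheme, reusing the parsing shown in this hunk (the tile tuple is illustrative):

# Sketch: build the config string embedded in generated launch-kernel file names.
tile = ("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")  # (block, warp, mma) shapes
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
    f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
    f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
    f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
print(gemm_config + "_stage3")  # block64x64x64_warp32x32x64_mma16x8x32_stage3
# -> launch_dual_gemm_kernel_block64x64x64_warp32x32x64_mma16x8x32_stage3.cu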
@@ -566,12 +561,12 @@ def generate_dispatch_dual_gemm_cu(
tile_id = 0 tile_id = 0
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile gemm_config = (
] f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_" f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}") f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = { value_dict = {
@@ -580,10 +575,12 @@ def generate_dispatch_dual_gemm_cu(
} }
all_code += SubstituteTemplate(code_part5, value_dict) all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1 tile_id += 1
value_dict.update({ value_dict.update(
{
"min_split_k": str(min_split_k), "min_split_k": str(min_split_k),
"max_split_k": str(max_split_k), "max_split_k": str(max_split_k),
}) }
)
all_code += SubstituteTemplate(code_part6, value_dict) all_code += SubstituteTemplate(code_part6, value_dict)
return all_code return all_code
@@ -602,8 +599,7 @@ if __name__ == "__main__":
for sm in archs: for sm in archs:
if sm == "89": if sm == "89":
fuse_gemm_configs = get_dual_gemm_candidate_configs( fuse_gemm_configs = get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = ( file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
View File
@@ -19,8 +19,7 @@ import re
def get_candidate_tiles(): def get_candidate_tiles():
""" """ """
"""
cta_shape = [ cta_shape = [
("<_64, _16, _128>"), ("<_64, _16, _128>"),
("<_64, _32, _128>"), ("<_64, _32, _128>"),
@@ -45,8 +44,7 @@ def get_candidate_tiles():
def get_dual_gemm_candidate_configs(sm): def get_dual_gemm_candidate_configs(sm):
""" """ """
"""
tiles = get_candidate_tiles() tiles = get_candidate_tiles()
candidate_configs = list() candidate_configs = list()
@@ -64,35 +62,27 @@ def get_dual_gemm_candidate_configs(sm):
("swiglu", "SiLu"), ("swiglu", "SiLu"),
("geglu", "GELU"), ("geglu", "GELU"),
]: ]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
EpilogueSchedule)])
return candidate_configs return candidate_configs
def get_shape_str(tile_shape): def get_shape_str(tile_shape):
""" """ """
""" blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks] blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters] clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule): def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
""" """ """
"""
blocks, clusters = get_shape_str(tile_shape) blocks, clusters = get_shape_str(tile_shape)
if int( if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule: if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False return False
if tile_shape[ if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
return False return False
return True return True
@@ -302,8 +292,7 @@ bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
def SubstituteTemplate(template, values): def SubstituteTemplate(template, values):
""" """ """
"""
text = template text = template
changed = True changed = True
while changed: while changed:
@@ -318,10 +307,8 @@ def SubstituteTemplate(template, values):
def parse_args(): def parse_args():
""" """ """
""" parser = argparse.ArgumentParser(description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser = argparse.ArgumentParser(
description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
type=str, type=str,
@@ -336,17 +323,16 @@ def parse_args():
# generate source .cu # generate source .cu
def generate_dual_gemm_source_cu( def generate_dual_gemm_source_cu(
inputs_type: (str), inputs_type: str,
biases_type: (str), biases_type: str,
hasbiases: (str), hasbiases: str,
act_tag: (str), act_tag: str,
tiles: (str), tiles: str,
KernelSchedule: (str), KernelSchedule: str,
EpilogueSchedule: (str), EpilogueSchedule: str,
sm: str, sm: str,
): ):
""" """ """
"""
all_code = CommonHead all_code = CommonHead
for input_type in inputs_type: for input_type in inputs_type:
for bias_type in biases_type: for bias_type in biases_type:
@@ -354,9 +340,7 @@ def generate_dual_gemm_source_cu(
for tile_config in tiles: for tile_config in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
kernel_schedule,
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"input_type": input_type, "input_type": input_type,
@@ -370,28 +354,29 @@ def generate_dual_gemm_source_cu(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmDeclare, value_dict)
GemmDeclare, value_dict)
return all_code return all_code
# generate gemm launch .cu # generate gemm launch .cu
def generate_launch_dual_gemm_cus( def generate_launch_dual_gemm_cus(
generate_dir: (str), inputs_type: (str), biases_type: (str), generate_dir: str,
fuse_gemm_configs: tuple, sm: str): inputs_type: str,
""" biases_type: str,
""" fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, head_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile_config in tiles: for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config) blocks, clusters = get_shape_str(tile_config)
@@ -401,16 +386,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
head_all_code += SubstituteTemplate(LaunchGemmDeclare, head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
value_dict)
os.makedirs(generate_dir, exist_ok=True) os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f: with open(head_path, "w") as f:
f.write(head_all_code) f.write(head_all_code)
@@ -422,16 +405,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
source_all_code = SubstituteTemplate(LaunchGemmPart0, source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
value_dict)
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
for bias_type in biases_type: for bias_type in biases_type:
@@ -450,14 +431,13 @@ def generate_launch_dual_gemm_cus(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu" f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
) )
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
@@ -467,16 +447,14 @@ def generate_launch_dual_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act.cu # generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str), def generate_dispatch_dual_gemm_cu(inputs_type: str, biases_type: str, fuse_gemm_configs: tuple, sm: str):
fuse_gemm_configs: tuple, sm: str): """ """
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]}) all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0 type_id = 0
@@ -500,8 +478,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for tile_shape in tiles: for tile_shape in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"TileShape": tile_shape[0], "TileShape": tile_shape[0],
@@ -520,8 +497,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
@@ -570,12 +546,15 @@ if __name__ == "__main__":
f.close() f.close()
# Compile parallelization # Compile parallelization
generate_launch_dual_gemm_cus( generate_launch_dual_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type, "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
biases_type, fuse_gemm_configs, sm_dict[sm]) inputs_type,
biases_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag # hard code for act_tag
file_name = ( file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
) )
all_code = generate_dispatch_dual_gemm_cu( all_code = generate_dispatch_dual_gemm_cu(
inputs_type, inputs_type,


@@ -31,7 +31,8 @@ def get_candidate_tiles():
""" """
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")] base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([ base_configs.extend(
[
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"), ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"), ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"), ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
@@ -43,13 +44,13 @@ def get_candidate_tiles():
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"), ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]) ]
)
return base_configs return base_configs
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
max_stages):
""" """
    Get the list of candidate GEMM operator configurations.     Get the list of candidate GEMM operator configurations.
@@ -353,8 +354,7 @@ def parse_args():
    Parse command-line arguments.     Parse command-line arguments.
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -448,8 +448,7 @@ def generate_source_cu(
"hasbias": hasbias, "hasbias": hasbias,
"SM": sm, "SM": sm,
} }
all_code += SubstituteTemplate(GemmSplitKDeclare, all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
value_dict)
all_code += CommonTail all_code += CommonTail
return all_code return all_code
@@ -473,9 +472,7 @@ def generate_launch_gemm_cus(
head_path = os.path.join(generate_dir, "launch_gemm_kernel.h") head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -489,9 +486,7 @@ def generate_launch_gemm_cus(
f.close() f.close()
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -517,17 +512,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage), "num_stages": str(stage),
"SM": sm, "SM": sm,
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict) split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
split_k_code += SubstituteTemplate(
LaunchGemmPart3, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
source_all_code += split_k_code source_all_code += split_k_code
source_all_code += LaunchGemmPart4 source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -581,9 +573,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4 all_code += code_part4
tile_id = 0 tile_id = 0
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -593,10 +583,12 @@ def generate_dispatch_gemm_cu(
} }
all_code += SubstituteTemplate(code_part5, value_dict) all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1 tile_id += 1
value_dict.update({ value_dict.update(
{
"min_split_k": str(min_split_k), "min_split_k": str(min_split_k),
"max_split_k": str(max_split_k), "max_split_k": str(max_split_k),
}) }
)
all_code += SubstituteTemplate(code_part6, value_dict) all_code += SubstituteTemplate(code_part6, value_dict)
return all_code return all_code
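
Both generators above assemble their output by repeatedly calling SubstituteTemplate on string templates such as code_part5 and code_part6. As a rough mental model only, a minimal stand-in is sketched below; it assumes ${name}-style placeholders, which is an assumption about the template syntax rather than something shown in this hunk (the repository's own SubstituteTemplate, which appears later in this diff, additionally rewrites "Auto" kernel/epilogue schedules).

```python
# Illustrative stand-in only; assumes ${name}-style placeholders.
# The real SubstituteTemplate in this repository may differ.
import re


def substitute_template(template: str, values: dict) -> str:
    """Replace ${key} placeholders repeatedly until nothing changes."""
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            pattern = r"\$\{%s\}" % re.escape(key)
            new_text = re.sub(pattern, value, text)
            if new_text != text:
                changed = True
                text = new_text
    return text


result = substitute_template(
    "launch_gemm_sm${sm}_${gemm_config}",
    {"sm": "89", "gemm_config": "block128x128x64_stage3"},
)
print(result)  # launch_gemm_sm89_block128x128x64_stage3
```
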
@@ -614,9 +606,7 @@ if __name__ == "__main__":
for sm in archs: for sm in archs:
if sm == "89": if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_split_k, fuse_gemm_configs = get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
max_split_k, min_stages,
max_stages)
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu" file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
all_code = generate_source_cu( all_code = generate_source_cu(
@@ -654,9 +644,7 @@ if __name__ == "__main__":
# hard code for act_tag # hard code for act_tag
file_name = ( file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
)
all_code = generate_dispatch_gemm_cu( all_code = generate_dispatch_gemm_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,


@@ -20,14 +20,14 @@ import re
def get_candidate_tiles(): def get_candidate_tiles():
""" """ """
"""
base_configs = [ base_configs = [
("<_64, _64, _128>", "<_1, _8, _1>"), ("<_64, _64, _128>", "<_1, _8, _1>"),
("<_64, _128, _128>", "<_2, _1, _1>"), ("<_64, _128, _128>", "<_2, _1, _1>"),
("<_128, _128, _128>", "<_2, _1, _1>"), ("<_128, _128, _128>", "<_2, _1, _1>"),
] ]
base_configs.extend([ base_configs.extend(
[
("<_64, _64, _128>", "<_1, _1, _1>"), ("<_64, _64, _128>", "<_1, _1, _1>"),
("<_64, _64, _128>", "<_1, _2, _1>"), ("<_64, _64, _128>", "<_1, _2, _1>"),
("<_64, _64, _128>", "<_2, _1, _1>"), ("<_64, _64, _128>", "<_2, _1, _1>"),
@@ -50,14 +50,14 @@ def get_candidate_tiles():
# ("<_128, _128, _128>", "<_1, _1, _1>"), # ("<_128, _128, _128>", "<_1, _1, _1>"),
# ("<_128, _128, _128>", "<_1, _4, _1>"), # ("<_128, _128, _128>", "<_1, _4, _1>"),
# ("<_128, _128, _64>", "<_2, _2, _1>"), # ("<_128, _128, _64>", "<_2, _2, _1>"),
]) ]
)
return base_configs return base_configs
def get_candidate_configs(sm): def get_candidate_configs(sm):
""" """ """
"""
tiles = get_candidate_tiles() tiles = get_candidate_tiles()
candidate_configs = list() candidate_configs = list()
@@ -73,36 +73,31 @@ def get_candidate_configs(sm):
("relu", "ReLu"), ("relu", "ReLu"),
("gelu", "GELU"), ("gelu", "GELU"),
]: ]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
EpilogueSchedule)])
return candidate_configs return candidate_configs
def get_shape_str(tile_shape): def get_shape_str(tile_shape):
""" """ """
""" blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks] blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters] clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule): def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
""" """ """
"""
blocks, clusters = get_shape_str(tile_shape) blocks, clusters = get_shape_str(tile_shape)
if int( if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule: if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False return False
if (tile_shape[0] == "<_256, _128, _128>" if (
tile_shape[0] == "<_256, _128, _128>"
and "Cooperative" not in kernel_schedule and "Cooperative" not in kernel_schedule
and "Cooperative" not in epilogue_schedule): and "Cooperative" not in epilogue_schedule
):
return False return False
return True return True
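
get_shape_str and check_config_valid above do light string parsing on CUTLASS-style tile/cluster shape strings before filtering schedule combinations. A self-contained sketch of that parsing step, using a tile pair taken from the candidate list above:

```python
# Mirrors the parsing in get_shape_str(); the tile string comes from the candidate list above.
def parse_shape_str(tile_shape):
    blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
    blocks = [elem.strip("_") for elem in blocks]
    clusters = [elem.strip("_") for elem in clusters]
    return blocks, clusters


blocks, clusters = parse_shape_str(("<_128, _128, _128>", "<_2, _1, _1>"))
print(blocks)    # ['128', '128', '128']
print(clusters)  # ['2', '1', '1']
```
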
@@ -321,16 +316,12 @@ bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
def SubstituteTemplate(template, values_base): def SubstituteTemplate(template, values_base):
""" """ """
"""
values = copy.deepcopy(values_base) values = copy.deepcopy(values_base)
if values.get("KernelSchedule" if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
) is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"] values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule" if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
) is not None and "Auto" in values["EpilogueSchedule"]: values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template text = template
changed = True changed = True
while changed: while changed:
@@ -345,10 +336,8 @@ def SubstituteTemplate(template, values_base):
def parse_args(): def parse_args():
""" """ """
""" parser = argparse.ArgumentParser(description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser = argparse.ArgumentParser(
description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
type=str, type=str,
@@ -363,17 +352,16 @@ def parse_args():
# generate source .cu # generate source .cu
def generate_source_cu( def generate_source_cu(
inputs_type: (str), inputs_type: str,
outputs_type: (str), outputs_type: str,
hasbiases: (str), hasbiases: str,
act_tag: (str), act_tag: str,
tiles: (str), tiles: str,
KernelSchedule: (str), KernelSchedule: str,
EpilogueSchedule: (str), EpilogueSchedule: str,
sm: str, sm: str,
): ):
""" """ """
"""
all_code = CommonHead all_code = CommonHead
for input_type in inputs_type: for input_type in inputs_type:
@@ -382,9 +370,7 @@ def generate_source_cu(
for tile_config in tiles: for tile_config in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
kernel_schedule,
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"input_type": input_type, "input_type": input_type,
@@ -398,25 +384,27 @@ def generate_source_cu(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmDeclare, value_dict)
GemmDeclare, value_dict)
return all_code return all_code
# generate gemm launch .cu # generate gemm launch .cu
def generate_launch_gemm_cus( def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str), generate_dir: str,
fuse_gemm_configs: tuple, sm: str): inputs_type: str,
""" outputs_type: str,
""" fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h") head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
@@ -426,16 +414,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
head_all_code += SubstituteTemplate(LaunchGemmDeclare, head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
value_dict)
os.makedirs(generate_dir, exist_ok=True) os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f: with open(head_path, "w") as f:
f.write(head_all_code) f.write(head_all_code)
@@ -447,16 +433,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
source_all_code = SubstituteTemplate(LaunchGemmPart0, source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
value_dict)
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
for output_type in outputs_type: for output_type in outputs_type:
@@ -475,14 +459,14 @@ def generate_launch_gemm_cus(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu") f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -491,17 +475,15 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu # generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str), def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
fuse_gemm_configs: tuple, sm: str): """ """
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]}) all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0 type_id = 0
@@ -524,8 +506,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for tile_shape in tiles: for tile_shape in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"TileShape": tile_shape[0], "TileShape": tile_shape[0],
@@ -544,8 +525,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
@@ -576,7 +556,8 @@ if __name__ == "__main__":
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = ( file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu") f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
)
all_code = generate_source_cu( all_code = generate_source_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,
@@ -594,8 +575,12 @@ if __name__ == "__main__":
f.close() f.close()
# Compile parallelization # Compile parallelization
generate_launch_gemm_cus( generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type, "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
outputs_type, fuse_gemm_configs, sm_dict[sm]) inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag # hard code for act_tag
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu" file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"


@@ -30,7 +30,8 @@ def get_candidate_tiles():
""" """
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")] base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([ base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"), ("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"), ("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"), ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
@@ -45,7 +46,8 @@ def get_candidate_tiles():
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"), ("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"), ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]) ]
)
return base_configs return base_configs
@@ -278,8 +280,7 @@ def parse_args():
    Parse command-line arguments.     Parse command-line arguments.
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -370,13 +371,10 @@ def generate_launch_gemm_cus(
    - dict (code_map) - a dict mapping each GEMM configuration to its generated source code, in the form {"gemm_config": source_code}.     - dict (code_map) - a dict mapping each GEMM configuration to its generated source code, in the form {"gemm_config": source_code}.
""" """
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, head_path = os.path.join(generate_dir, "launch_visitor_gemm_fused_kernel.h")
"launch_visitor_gemm_fused_kernel.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -390,9 +388,7 @@ def generate_launch_gemm_cus(
f.close() f.close()
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -415,14 +411,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage), "num_stages": str(stage),
"SM": sm, "SM": sm,
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu") f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu",
)
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -485,9 +481,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4 all_code += code_part4
tile_id = 0 tile_id = 0
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -512,10 +506,11 @@ if __name__ == "__main__":
for sm in archs: for sm in archs:
if sm == "89": if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_stages, fuse_gemm_configs = get_candidate_configs(sm, min_stages, max_stages)
max_stages)
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu" file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
)
all_code = generate_source_cu( all_code = generate_source_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,
@@ -544,9 +539,7 @@ if __name__ == "__main__":
sm_dict[sm], sm_dict[sm],
) )
file_name = ( file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
)
all_code = generate_dispatch_gemm_cu( all_code = generate_dispatch_gemm_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,

View File

@@ -30,8 +30,7 @@ current_file = Path(__file__).resolve()
base_dir = current_file.parent base_dir = current_file.parent
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
XDNN_LIB_DIR):
""" """
build xpu plugin build xpu plugin
""" """
@@ -49,7 +48,10 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
    # Remove the specified directories     # Remove the specified directories
dirs_to_remove = [ dirs_to_remove = [
"dist", "fastdeploy_ops.egg-info", "build", "plugin/build" "dist",
"fastdeploy_ops.egg-info",
"build",
"plugin/build",
] ]
for dir_name in dirs_to_remove: for dir_name in dirs_to_remove:
if os.path.exists(dir_name): if os.path.exists(dir_name):
@@ -58,8 +60,7 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
    # Run the build script in the plugin directory     # Run the build script in the plugin directory
plugin_dir = "plugin" plugin_dir = "plugin"
build_script = os.path.join(current_working_directory, plugin_dir, build_script = os.path.join(current_working_directory, plugin_dir, "build.sh")
"build.sh")
print("build_script: ", build_script) print("build_script: ", build_script)
@@ -74,14 +75,16 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
    # Run the build script     # Run the build script
try: try:
print("Running build script...") print("Running build script...")
subprocess.run([build_script], subprocess.run(
[build_script],
check=True, check=True,
cwd=os.path.join(current_working_directory, plugin_dir)) cwd=os.path.join(current_working_directory, plugin_dir),
)
print("Build completed successfully.") print("Build completed successfully.")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Build failed with error: {e}") print(f"Build failed with error: {e}")
except Exception as e: except Exception as e:
print(f"Unexpected error: {str(e)}") print(f"Unexpected error: {e!s}")
def xpu_setup_ops(): def xpu_setup_ops():
@@ -124,17 +127,14 @@ def xpu_setup_ops():
XVLLM_PATH = os.getenv("XVLLM_PATH") XVLLM_PATH = os.getenv("XVLLM_PATH")
assert XVLLM_PATH is not None, "XVLLM_PATH is not set." assert XVLLM_PATH is not None, "XVLLM_PATH is not set."
XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include") XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include")
XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so", XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so", "libapiinfer.so")
"libapiinfer.so")
XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so") XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so")
XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include") XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include")
XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so", XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so", "libxft_blocks.so")
"libxft_blocks.so")
XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so") XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so")
# build plugin # build plugin
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
XDNN_LIB_DIR)
ops = [ ops = [
# custom ops # custom ops
@@ -152,7 +152,6 @@ def xpu_setup_ops():
"./ops/block_attn.cc", "./ops/block_attn.cc",
"./ops/moe_layer.cc", "./ops/moe_layer.cc",
"./ops/weight_quantize_xpu.cc", "./ops/weight_quantize_xpu.cc",
# device manage ops # device manage ops
"./ops/device/get_context_gm_max_mem_demand.cc", "./ops/device/get_context_gm_max_mem_demand.cc",
"./ops/device/get_free_global_memory.cc", "./ops/device/get_free_global_memory.cc",

View File

@@ -29,7 +29,7 @@ for i in range(bs):
ids_len = seq_lens[i, 0] ids_len = seq_lens[i, 0]
input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64") input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64")
x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset( (x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k,) = get_padding_offset(
paddle.to_tensor(input_ids), paddle.to_tensor(input_ids),
paddle.to_tensor(cum_offset), paddle.to_tensor(cum_offset),
paddle.to_tensor(token_num), paddle.to_tensor(token_num),
@@ -46,19 +46,14 @@ print("padding_offset:\n", padding_offset)
print("cu_seqlens_q:\n", cu_seqlens_q) print("cu_seqlens_q:\n", cu_seqlens_q)
print("cu_seqlens_k:\n", cu_seqlens_k) print("cu_seqlens_k:\n", cu_seqlens_k)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
"int64")
ref_cum_offsets_out = np.array([0, 6, 13], "int32") ref_cum_offsets_out = np.array([0, 6, 13], "int32")
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
"int32")
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32") ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32") ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
assert sum(ref_x_remove_padding - assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
x_remove_padding) == 0, 'Check x_remove_padding failed.' assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
assert sum(ref_cum_offsets_out - assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
cum_offsets_out) == 0, 'Check cum_offsets_out failed.' assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
assert sum(ref_padding_offset - assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
padding_offset) == 0, 'Check padding_offset failed.'
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, 'Check cu_seqlens_q failed.'
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, 'Check cu_seqlens_k failed.'
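
The checks above compare arrays by summing element-wise differences. A stricter alternative (not what this test uses) is NumPy's testing helper, which also fails when positive and negative differences cancel out in the sum:

```python
# Alternative check style, shown for comparison only (the test keeps its sum-based asserts).
import numpy as np

ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
cu_seqlens_q = np.array([0, 4, 7, 13], "int32")  # stands in for the op output above

# Fails on any element-wise mismatch, even when differences would cancel in a sum.
np.testing.assert_array_equal(ref_cu_seqlens_q, cu_seqlens_q)
```
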


@@ -21,10 +21,15 @@ paddle.seed(2023)
pre_ids = paddle.to_tensor( pre_ids = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]], [[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64") "int64",
logits = paddle.to_tensor([[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1], )
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1]], logits = paddle.to_tensor(
"float32") [
[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1],
],
"float32",
)
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32") penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
frequency_scores = paddle.to_tensor([0.1, 0.1], "float32") frequency_scores = paddle.to_tensor([0.1, 0.1], "float32")
presence_scores = paddle.to_tensor([0.0, 0.0], "float32") presence_scores = paddle.to_tensor([0.0, 0.0], "float32")
@@ -88,78 +93,536 @@ ref_logits = np.array(
) )
diff_logits = np.sum(np.abs(ref_logits - logits.numpy())) diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits) print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, 'Check failed.' assert diff_logits < 1e-6, "Check failed."
pre_ids = paddle.to_tensor( pre_ids = paddle.to_tensor(
[[ [
2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8, [
7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4, 2,
9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6, 3,
1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8, 3,
7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4, 5,
7, 5, 5, 7, 6, 9, 3, 9 8,
9,
3,
9,
1,
8,
9,
2,
3,
8,
8,
9,
9,
1,
4,
2,
6,
2,
6,
8,
7,
2,
2,
3,
8,
1,
5,
7,
9,
2,
2,
9,
1,
4,
9,
8,
5,
8,
5,
7,
3,
6,
4,
4,
9,
9,
8,
5,
5,
2,
2,
9,
4,
8,
1,
9,
6,
9,
2,
2,
7,
2,
2,
9,
4,
6,
4,
6,
1,
4,
1,
9,
1,
8,
8,
5,
7,
9,
4,
2,
5,
1,
1,
4,
1,
5,
5,
4,
4,
2,
1,
8,
7,
1,
2,
9,
6,
7,
9,
6,
7,
7,
4,
9,
9,
7,
5,
1,
8,
9,
8,
8,
5,
4,
6,
4,
7,
5,
5,
7,
6,
9,
3,
9,
], ],
[ [
7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1, 7,
9, 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1, 8,
5, 1, 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8, 1,
8, 4, 8, 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6, 3,
3, 5, 1, 4, 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6, 1,
4, 7, 7, 3, 8, 5, 1, 9, 1, 2, 8, 6, 8 7,
]]) 6,
3,
5,
3,
8,
3,
1,
9,
7,
1,
1,
9,
5,
4,
9,
6,
1,
9,
3,
8,
3,
9,
9,
6,
4,
2,
8,
5,
3,
1,
6,
9,
1,
3,
9,
8,
1,
7,
5,
1,
5,
1,
8,
7,
4,
5,
9,
8,
7,
4,
7,
3,
6,
4,
6,
6,
5,
5,
2,
9,
9,
5,
8,
8,
4,
8,
2,
8,
1,
3,
9,
1,
8,
5,
8,
3,
8,
8,
2,
7,
3,
7,
5,
7,
2,
6,
3,
5,
1,
4,
6,
1,
9,
8,
2,
2,
3,
6,
7,
6,
2,
6,
5,
1,
5,
6,
2,
1,
6,
4,
7,
7,
3,
8,
5,
1,
9,
1,
2,
8,
6,
8,
],
]
)
logits = paddle.to_tensor( logits = paddle.to_tensor(
[[ [
0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748, [
0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807, 0.16274983,
0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218, 0.61470598,
0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679, 0.94366980,
0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888, 0.82005417,
0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911, 0.50752640,
0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448, 0.38316748,
0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415, 0.92648441,
0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104, 0.24050158,
0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064, 0.05461595,
0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672, 0.42218581,
0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840, 0.36270225,
0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613, 0.15464807,
0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254, 0.13614719,
0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189, 0.67509544,
0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407, 0.40315166,
0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923, 0.10671722,
0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550, 0.24832056,
0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222, 0.76091218,
0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845, 0.11598995,
0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070, 0.10962527,
0.98481447, 0.77367836 0.04688513,
0.81536716,
0.72259802,
0.60476679,
0.16701800,
0.84160781,
0.79649884,
0.78021604,
0.75329530,
0.98587888,
0.13421868,
0.16027625,
0.15269397,
0.06228730,
0.73856270,
0.34721911,
0.73683006,
0.78178608,
0.32068327,
0.79906309,
0.44214272,
0.63330448,
0.08016958,
0.63367140,
0.19788943,
0.55346787,
0.11142531,
0.90518415,
0.21236691,
0.81587470,
0.83752930,
0.70979482,
0.35684183,
0.28715104,
0.87162822,
0.17679396,
0.98725849,
0.76129991,
0.04090235,
0.37181064,
0.63317049,
0.24689502,
0.21126501,
0.57617670,
0.74346697,
0.40613672,
0.56907010,
0.68556929,
0.29032683,
0.17866278,
0.35165095,
0.97015840,
0.70785582,
0.54259878,
0.14712237,
0.90483177,
0.02094105,
0.36411613,
0.02495066,
0.88874054,
0.88895452,
0.86216462,
0.58062190,
0.95583254,
0.20553111,
0.29870346,
0.69652933,
0.36861244,
0.85316223,
0.50240189,
0.17566244,
0.61080140,
0.88203174,
0.98675215,
0.24344546,
0.17213407,
0.78160852,
0.25165486,
0.48188508,
0.82812423,
0.10199814,
0.90475923,
0.66907483,
0.71910626,
0.40660757,
0.59460294,
0.70212913,
0.90841550,
0.00329034,
0.11290466,
0.89654654,
0.69114941,
0.29473618,
0.62027222,
0.37333879,
0.98911142,
0.46510187,
0.65914583,
0.73022646,
0.12790845,
0.12817244,
0.43015456,
0.75011456,
0.43562204,
0.48086026,
0.75587070,
0.98481447,
0.77367836,
], ],
[ [
0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417, 0.12336024,
0.01848883, 0.78326035, 0.99228370, 0.81447607, 0.02627683, 0.74152875,
0.51033205, 0.98703283, 0.15247856, 0.77640921, 0.60799915, 0.09191196,
0.87518770, 0.76818430, 0.86542630, 0.31795895, 0.04829503, 0.99301219,
0.85567141, 0.30271924, 0.67515039, 0.59728831, 0.78710967, 0.44764417,
0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547, 0.01848883,
0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396, 0.78326035,
0.95318258, 0.93987304, 0.61142951, 0.26737869, 0.52285451, 0.99228370,
0.03479086, 0.61631846, 0.66777998, 0.15736090, 0.00447258, 0.81447607,
0.37035006, 0.15281211, 0.95372260, 0.25963321, 0.61036694, 0.02627683,
0.15020694, 0.19171195, 0.55252832, 0.00391038, 0.31052542, 0.51033205,
0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293, 0.98703283,
0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045, 0.15247856,
0.13202040, 0.07133843, 0.75313550, 0.17111187, 0.80716974, 0.77640921,
0.00172165, 0.83906764, 0.73240769, 0.85843354, 0.11042888, 0.60799915,
0.07912333, 0.33689004, 0.22334915, 0.59059596, 0.52789515, 0.87518770,
0.29831955, 0.39515004, 0.55602801, 0.83818001, 0.05865780, 0.76818430,
0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544, 0.86542630,
0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529, 0.31795895,
0.29571742, 0.76361233, 0.72476572, 0.18568406, 0.85430276, 0.04829503,
0.02057583, 0.76195669, 0.65507215, 0.69129735, 0.25084621, 0.85567141,
0.75223947, 0.06064088, 0.20287007, 0.35887691, 0.75043523, 0.30271924,
0.47575447, 0.40021798, 0.44464844, 0.67975360, 0.40443239, 0.67515039,
0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721, 0.59728831,
0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736, 0.78710967,
0.87317258, 0.88229448, 0.79217333 0.75111693,
]]) 0.56837374,
0.49085775,
0.91510201,
0.59545547,
0.99482232,
0.59036905,
0.58267909,
0.28770933,
0.53237396,
0.95318258,
0.93987304,
0.61142951,
0.26737869,
0.52285451,
0.03479086,
0.61631846,
0.66777998,
0.15736090,
0.00447258,
0.37035006,
0.15281211,
0.95372260,
0.25963321,
0.61036694,
0.15020694,
0.19171195,
0.55252832,
0.00391038,
0.31052542,
0.96495175,
0.42586124,
0.05630261,
0.99728668,
0.01856293,
0.83201504,
0.10701843,
0.56434178,
0.38009524,
0.51095045,
0.13202040,
0.07133843,
0.75313550,
0.17111187,
0.80716974,
0.00172165,
0.83906764,
0.73240769,
0.85843354,
0.11042888,
0.07912333,
0.33689004,
0.22334915,
0.59059596,
0.52789515,
0.29831955,
0.39515004,
0.55602801,
0.83818001,
0.05865780,
0.25654668,
0.76624149,
0.35190639,
0.04158346,
0.59157544,
0.30779791,
0.94609004,
0.10759670,
0.65575141,
0.37828529,
0.29571742,
0.76361233,
0.72476572,
0.18568406,
0.85430276,
0.02057583,
0.76195669,
0.65507215,
0.69129735,
0.25084621,
0.75223947,
0.06064088,
0.20287007,
0.35887691,
0.75043523,
0.47575447,
0.40021798,
0.44464844,
0.67975360,
0.40443239,
0.71052992,
0.21782248,
0.50568426,
0.89037591,
0.06661721,
0.28788096,
0.70773387,
0.42428264,
0.80419677,
0.42710736,
0.87317258,
0.88229448,
0.79217333,
],
]
)
# pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024]))) # pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
# logits = paddle.to_tensor(np.float32(np.random.random([2, 1024]))) # logits = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32") penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
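
Most of the churn above is mechanical: once a literal or call no longer fits within Black's line length, Black breaks it open with one element per line and a trailing comma, which is why the test data balloons in line count while the values stay identical. A small illustration of the two shapes Black produces (values copied from this file; the snippet itself is not repository code):

```python
import numpy as np

# Fits on one line, so Black leaves it alone (values from this test file):
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")

# A long literal (abridged here) is exploded: the call and the list are broken open,
# one element per line, and the trailing comma keeps them that way on later runs.
sample = np.array(
    [
        0.16274983,
        0.61470598,
        0.94366980,
        # ... remaining values elided for brevity ...
    ],
    "float32",
)
```
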
@@ -195,60 +658,270 @@ print("min_len\n", min_len)
print("eos_token_id\n", eos_token_id) print("eos_token_id\n", eos_token_id)
ref_logits = np.array( ref_logits = np.array(
[[ [
-10000000000., -10000000000., 1.88733959, 1.64010835, 1.01505280, [
0.76633495, 1.85296881, 0.48100317, 0.10923190, 0.84437162, 0.72540450, -10000000000.0,
0.30929613, 0.27229437, 1.35019088, 0.80630332, 0.21343444, 0.49664113, -10000000000.0,
1.52182436, 0.23197991, 0.21925054, 0.09377026, 1.63073432, 1.44519603, 1.88733959,
1.20953357, 0.33403599, 1.68321562, 1.59299767, 1.56043208, 1.50659060, 1.64010835,
1.97175777, 0.26843736, 0.32055250, 0.30538794, 0.12457460, 1.47712541, 1.01505280,
0.69443822, 1.47366011, 1.56357217, 0.64136654, 1.59812617, 0.88428545, 0.76633495,
1.26660895, 0.16033916, 1.26734281, 0.39577886, 1.10693574, 0.22285062, 1.85296881,
1.81036830, 0.42473382, 1.63174939, 1.67505860, 1.41958964, 0.71368366, 0.48100317,
0.57430208, 1.74325645, 0.35358793, 1.97451699, 1.52259982, 0.08180470, 0.10923190,
0.74362129, 1.26634097, 0.49379003, 0.42253003, 1.15235341, 1.48693395, 0.84437162,
0.81227344, 1.13814020, 1.37113857, 0.58065367, 0.35732555, 0.70330191, 0.72540450,
1.94031680, 1.41571164, 1.08519757, 0.29424474, 1.80966353, 0.04188210, 0.30929613,
0.72823226, 0.04990132, 1.77748108, 1.77790904, 1.72432923, 1.16124380, 0.27229437,
1.91166508, 0.41106221, 0.59740692, 1.39305866, 0.73722488, 1.70632446, 1.35019088,
1.00480378, 0.35132489, 1.22160280, 1.76406348, 1.97350430, 0.48689091, 0.80630332,
0.34426814, 1.56321704, 0.50330973, 0.96377015, 1.65624845, 0.20399629, 0.21343444,
1.80951846, 1.33814967, 1.43821251, 0.81321514, 1.18920588, 1.40425825, 0.49664113,
1.81683099, 0.00658068, 0.22580932, 1.79309309, 1.38229883, 0.58947235, 1.52182436,
1.24054444, 0.74667758, 1.97822285, 0.93020374, 1.31829166, 1.46045291, 0.23197991,
0.25581691, 0.25634488, 0.86030912, 1.50022912, 0.87124407, 0.96172053, 0.21925054,
1.51174140, 1.96962893, 1.54735672 0.09377026,
1.63073432,
1.44519603,
1.20953357,
0.33403599,
1.68321562,
1.59299767,
1.56043208,
1.50659060,
1.97175777,
0.26843736,
0.32055250,
0.30538794,
0.12457460,
1.47712541,
0.69443822,
1.47366011,
1.56357217,
0.64136654,
1.59812617,
0.88428545,
1.26660895,
0.16033916,
1.26734281,
0.39577886,
1.10693574,
0.22285062,
1.81036830,
0.42473382,
1.63174939,
1.67505860,
1.41958964,
0.71368366,
0.57430208,
1.74325645,
0.35358793,
1.97451699,
1.52259982,
0.08180470,
0.74362129,
1.26634097,
0.49379003,
0.42253003,
1.15235341,
1.48693395,
0.81227344,
1.13814020,
1.37113857,
0.58065367,
0.35732555,
0.70330191,
1.94031680,
1.41571164,
1.08519757,
0.29424474,
1.80966353,
0.04188210,
0.72823226,
0.04990132,
1.77748108,
1.77790904,
1.72432923,
1.16124380,
1.91166508,
0.41106221,
0.59740692,
1.39305866,
0.73722488,
1.70632446,
1.00480378,
0.35132489,
1.22160280,
1.76406348,
1.97350430,
0.48689091,
0.34426814,
1.56321704,
0.50330973,
0.96377015,
1.65624845,
0.20399629,
1.80951846,
1.33814967,
1.43821251,
0.81321514,
1.18920588,
1.40425825,
1.81683099,
0.00658068,
0.22580932,
1.79309309,
1.38229883,
0.58947235,
1.24054444,
0.74667758,
1.97822285,
0.93020374,
1.31829166,
1.46045291,
0.25581691,
0.25634488,
0.86030912,
1.50022912,
0.87124407,
0.96172053,
1.51174140,
1.96962893,
1.54735672,
], ],
[ [
-10000000000., -10000000000., -40000., 3.97204876, 1.79057670, -10000000000.0,
0.07395532, 3.13304138, 3.96913481, 3.25790429, -40000., 2.04132819, -10000000000.0,
3.94813132, 0.60991424, 3.10563684, 2.43199658, 3.50075078, -40000.0,
3.07273722, 3.46170521, 1.27183580, 0.19318011, 3.42268562, 3.97204876,
1.21087694, 2.70060158, 2.38915324, 3.14843869, 3.00446773, 1.79057670,
2.27349496, 1.96343100, 3.66040802, 2.38182187, 3.97928929, 0.07395532,
2.36147618, 2.33071637, 1.15083730, 2.12949586, 3.81273031, 3.13304138,
3.75949216, 2.44571805, 1.06951475, 2.09141803, 0.13916343, 3.96913481,
2.46527386, 2.67111993, 0.62944359, 0.01789032, 1.48140025, 3.25790429,
0.61124843, 3.81489038, 1.03853285, 2.44146776, 0.60082775, -40000.0,
0.76684779, 2.21011329, 0.01564152, 1.24210167, 3.85980701, 2.04132819,
1.70344496, 0.22521044, 3.98914671, 0.07425172, 3.32806015, 3.94813132,
0.42807373, 2.25736713, 1.52038097, 2.04380178, 0.52808160, 0.60991424,
0.28535372, 3.01254201, 0.68444747, 3.22867894, 0.00688660, 3.10563684,
3.35627055, 2.92963076, 3.43373418, 0.44171551, 0.31649333, 2.43199658,
1.34756017, 0.89339662, 2.36238384, 2.11158061, 1.19327819, 3.50075078,
1.58060014, 2.22411203, 3.35272002, 0.23463120, 1.02618670, 3.07273722,
3.06496596, 1.40762556, 0.16633384, 2.36630177, 1.23119164, 3.46170521,
3.78436017, 0.43038681, 2.62300563, 1.51314116, 1.18286967, 1.27183580,
3.05444932, 2.89906287, 0.74273622, 3.41721106, 0.08230332, 0.19318011,
3.04782677, 2.62028861, 2.76518941, 1.00338483, 3.00895786, 3.42268562,
0.24256352, 0.81148028, 1.43550766, 3.00174093, 1.90301788, 1.21087694,
1.60087192, 1.77859378, 2.71901441, 1.61772954, 2.84211969, 2.70060158,
0.87128991, 2.02273703, 3.56150365, 0.26646885, 1.15152383, 2.38915324,
2.83093548, 1.69713056, 3.21678710, 1.70842946, 3.49269032, 3.14843869,
3.52917790, 3.16869330 3.00446773,
]], 2.27349496,
1.96343100,
3.66040802,
2.38182187,
3.97928929,
2.36147618,
2.33071637,
1.15083730,
2.12949586,
3.81273031,
3.75949216,
2.44571805,
1.06951475,
2.09141803,
0.13916343,
2.46527386,
2.67111993,
0.62944359,
0.01789032,
1.48140025,
0.61124843,
3.81489038,
1.03853285,
2.44146776,
0.60082775,
0.76684779,
2.21011329,
0.01564152,
1.24210167,
3.85980701,
1.70344496,
0.22521044,
3.98914671,
0.07425172,
3.32806015,
0.42807373,
2.25736713,
1.52038097,
2.04380178,
0.52808160,
0.28535372,
3.01254201,
0.68444747,
3.22867894,
0.00688660,
3.35627055,
2.92963076,
3.43373418,
0.44171551,
0.31649333,
1.34756017,
0.89339662,
2.36238384,
2.11158061,
1.19327819,
1.58060014,
2.22411203,
3.35272002,
0.23463120,
1.02618670,
3.06496596,
1.40762556,
0.16633384,
2.36630177,
1.23119164,
3.78436017,
0.43038681,
2.62300563,
1.51314116,
1.18286967,
3.05444932,
2.89906287,
0.74273622,
3.41721106,
0.08230332,
3.04782677,
2.62028861,
2.76518941,
1.00338483,
3.00895786,
0.24256352,
0.81148028,
1.43550766,
3.00174093,
1.90301788,
1.60087192,
1.77859378,
2.71901441,
1.61772954,
2.84211969,
0.87128991,
2.02273703,
3.56150365,
0.26646885,
1.15152383,
2.83093548,
1.69713056,
3.21678710,
1.70842946,
3.49269032,
3.52917790,
3.16869330,
],
],
"float32", "float32",
) )
diff_logits = np.sum(np.abs(ref_logits - logits.numpy())) diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits) print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, 'Check failed.' assert diff_logits < 1e-6, "Check failed."


@@ -21,19 +21,30 @@ paddle.seed(2023)
pre_ids_all = paddle.to_tensor( pre_ids_all = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]], [[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64") "int64",
input_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1], )
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1]], input_ids = paddle.to_tensor(
"int64") [
[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1],
],
"int64",
)
seq_lens_this_time = paddle.to_tensor([1, 1], "int32") seq_lens_this_time = paddle.to_tensor([1, 1], "int32")
seq_lens_encoder = paddle.to_tensor([1, 1], "int32") seq_lens_encoder = paddle.to_tensor([1, 1], "int32")
seq_lens_decoder = paddle.to_tensor([1, 1], "int32") seq_lens_decoder = paddle.to_tensor([1, 1], "int32")
step_idx = paddle.to_tensor([1, 1], "int64") step_idx = paddle.to_tensor([1, 1], "int64")
stop_flags = paddle.to_tensor([0, 1], "bool") stop_flags = paddle.to_tensor([0, 1], "bool")
print("pre_ids_all\n", pre_ids_all) print("pre_ids_all\n", pre_ids_all)
set_value_by_flags_and_idx(pre_ids_all, input_ids, seq_lens_this_time, set_value_by_flags_and_idx(
seq_lens_encoder, seq_lens_decoder, step_idx, pre_ids_all,
stop_flags) input_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
stop_flags,
)
print("pre_ids_all\n", pre_ids_all) print("pre_ids_all\n", pre_ids_all)
print("input_ids\n", input_ids) print("input_ids\n", input_ids)
print("seq_lens_this_time\n", seq_lens_this_time) print("seq_lens_this_time\n", seq_lens_this_time)
@@ -73,4 +84,4 @@ ref_pre_ids_all = np.array(
) )
diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy())) diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy()))
print("diff_pre_ids_all\n", diff_pre_ids_all) print("diff_pre_ids_all\n", diff_pre_ids_all)
assert diff_pre_ids_all == 0, 'Check failed.' assert diff_pre_ids_all == 0, "Check failed."


@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
max_block_num = block_bs * max_seq_len // block_size max_block_num = block_bs * max_seq_len // block_size
free_list_len = int(max_block_num * (1 - block_ratio)) free_list_len = int(max_block_num * (1 - block_ratio))
free_list_len = np.full([1], free_list_len, "int32") free_list_len = np.full([1], free_list_len, "int32")
free_list = np.arange(max_block_num - 1, free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")
max_block_num - free_list_len - 1,
-1,
dtype="int32")
encoder_block_lens = np.zeros([max_bs], "int32") encoder_block_lens = np.zeros([max_bs], "int32")
used_list_len = np.zeros([max_bs], "int32") used_list_len = np.zeros([max_bs], "int32")
@@ -53,19 +50,15 @@ encoder_block_id = 0
for i in range(bs): for i in range(bs):
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
encoder_block_lens[i] = enc_block_num encoder_block_lens[i] = enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size - dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
1) // block_size - enc_block_num
used_list_len[i] = dec_block_num used_list_len[i] = dec_block_num
block_tables[i, :enc_block_num] = np.arange( block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id += enc_block_num encoder_block_id += enc_block_num
if dec_block_num > 0: if dec_block_num > 0:
block_tables[ block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
i, enc_block_num:enc_block_num + free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
dec_block_num] = free_list[free_list_len[0] - 1 - ]
dec_block_num:free_list_len[0] - 1] free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
1] = -1
free_list_len[0] -= dec_block_num free_list_len[0] -= dec_block_num
assert free_list_len[0] >= 0 assert free_list_len[0] >= 0
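
The loop above hands out decoder blocks from the tail of a shared free list and marks them as consumed. A compact NumPy re-creation of that allocation step, with made-up sizes:

```python
# Re-creation of the allocation step above; block counts are made up for illustration.
import numpy as np

max_block_num = 32
free_list_len = np.full([1], 8, "int32")
free_list = np.arange(max_block_num - 1, max_block_num - 8 - 1, -1, dtype="int32")  # [31 .. 24]
block_table = np.full([16], -1, "int32")

enc_block_num = 3   # blocks already holding the prompt
dec_block_num = 2   # extra blocks the decoder needs

# Take dec_block_num entries from the tail of the free list (the last slot stays untouched),
# write them into the block table, then mark them consumed, exactly as the test does.
block_table[enc_block_num : enc_block_num + dec_block_num] = free_list[
    free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
]
free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list_len[0] -= dec_block_num

print(block_table[:6])  # [-1 -1 -1 26 25 -1]
print(free_list)        # [31 30 29 28 27 -1 -1 24]
```
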
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
# print("step_idx: ", step_idx) # print("step_idx: ", step_idx)
# print("next_tokens: ", next_tokens) # print("next_tokens: ", next_tokens)
step_paddle(stop_flags, seq_lens_this_time, ori_seq_lens_encoder, step_paddle(
seq_lens_encoder, seq_lens_decoder, block_tables, stop_flags,
encoder_block_lens, is_block_step, step_block_list, step_lens, seq_lens_this_time,
recover_block_list, recover_lens, need_block_list, need_block_len, ori_seq_lens_encoder,
used_list_len, free_list, free_list_len, input_ids, pre_ids, seq_lens_encoder,
step_idx, next_tokens, first_token_ids, block_size, seq_lens_decoder,
encoder_decoder_block_num) block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_lens,
recover_block_list,
recover_lens,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
next_tokens,
first_token_ids,
block_size,
encoder_decoder_block_num,
)
print("-" * 50 + "after step op" + "-" * 50) print("-" * 50 + "after step op" + "-" * 50)
print("stop_flags: ", stop_flags) print("stop_flags: ", stop_flags)

View File

@@ -30,8 +30,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, False)
False)
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
@@ -40,44 +39,220 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array( ref_topk_ids = np.array(
[ [
0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, 18, 19, 20, 0,
0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, -1, 37, 38, 39, 0,
-1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, 0, -1, 0, 57, -1, 59, 2,
60, 0, 0, 63 3,
-1,
0,
0,
0,
0,
9,
10,
0,
12,
0,
-1,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
-1,
29,
30,
31,
0,
0,
0,
-1,
-1,
37,
38,
39,
-1,
-1,
0,
0,
0,
0,
46,
-1,
0,
49,
50,
0,
52,
53,
0,
-1,
0,
57,
-1,
59,
60,
0,
0,
63,
], ],
"int64", "int64",
) )
ref_next_tokens = np.array( ref_next_tokens = np.array(
[ [
0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, 18, 19, 20, 0,
0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, 0, 37, 38, 39, 0, 0,
0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, 0, 0, 0, 57, 0, 59, 60, 0, 2,
0, 63 3,
0,
0,
0,
0,
0,
9,
10,
0,
12,
0,
0,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
0,
29,
30,
31,
0,
0,
0,
0,
0,
37,
38,
39,
0,
0,
0,
0,
0,
0,
46,
0,
0,
49,
50,
0,
52,
53,
0,
0,
0,
57,
0,
59,
60,
0,
0,
63,
], ],
"int64", "int64",
) )
ref_stop_flags = np.array( ref_stop_flags = np.array(
[ [
True, True, True, True, True, True, True, True, True, False, False, True,
True, False, True, True, False, False, True, False, False, False, True, True,
False, False, True, False, False, False, True, False, False, False, True,
True, True, True, True, True, False, False, False, True, True, True, True,
True, True, True, False, True, True, False, False, True, False, False, True,
True, True, True, False, True, False, False, True, True, False True,
True,
True,
True,
False,
False,
True,
False,
True,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
False,
True,
True,
True,
True,
True,
False,
False,
False,
True,
True,
True,
True,
True,
True,
False,
True,
True,
False,
False,
True,
False,
False,
True,
True,
True,
False,
True,
False,
False,
True,
True,
False,
], ],
"bool", "bool",
) )
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy())) diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids) print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, 'Check failed.' assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy())) diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens) print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, 'Check failed.' assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum( diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
np.abs(
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags) print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, 'Check failed.' assert diff_stop_flags == 0, "Check failed."
# test beam_search=True # test beam_search=True
topk_ids = paddle.arange(0, bs, dtype="int64") topk_ids = paddle.arange(0, bs, dtype="int64")
@@ -88,8 +263,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, True)
True)
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
@@ -98,42 +272,217 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array( ref_topk_ids = np.array(
[ [
0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, 18, 19, 0,
20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, 36, 37, 0, 1,
-1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 2,
-1, 60, 61, -1, 63 3,
4,
0,
6,
7,
-1,
9,
10,
0,
-1,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
-1,
-1,
28,
29,
0,
0,
-1,
33,
34,
35,
36,
37,
0,
-1,
0,
41,
-1,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
-1,
60,
61,
-1,
63,
], ],
"int64", "int64",
) )
ref_next_tokens = np.array( ref_next_tokens = np.array(
[ [
0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, 18, 19, 20, 0,
0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, 36, 37, 0, 0, 0, 1,
41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 0, 60, 61, 2,
0, 63 3,
4,
0,
6,
7,
0,
9,
10,
0,
0,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
0,
0,
28,
29,
0,
0,
0,
33,
34,
35,
36,
37,
0,
0,
0,
41,
0,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
0,
60,
61,
0,
63,
], ],
"int64", "int64",
) )
ref_stop_flags = np.array( ref_stop_flags = np.array(
[ [
False, False, False, False, False, True, False, False, True, False, False,
False, True, True, False, False, False, True, False, False, False, False,
False, True, False, False, False, False, True, True, False, False, False,
True, True, True, False, False, False, False, False, True, True, True, False,
False, True, True, False, False, False, True, True, False, True, True, False,
True, False, True, True, True, True, False, True, False, False, True, True,
False False,
False,
True,
False,
False,
True,
True,
False,
False,
False,
True,
False,
False,
False,
False,
True,
False,
False,
False,
False,
True,
True,
False,
False,
True,
True,
True,
False,
False,
False,
False,
False,
True,
True,
True,
False,
True,
True,
False,
False,
False,
True,
True,
False,
True,
True,
True,
False,
True,
True,
True,
True,
False,
True,
False,
False,
True,
False,
], ],
"bool", "bool",
) )
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy())) diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids) print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, 'Check failed.' assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy())) diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens) print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, 'Check failed.' assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum( diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
np.abs(
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags) print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, 'Check failed.' assert diff_stop_flags == 0, "Check failed."
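
For readers tracing the expected values above, here is a rough NumPy emulation of the behavior these checks appear to exercise; it is an assumption for illustration only, and the authoritative behavior is the custom op itself. Slots that are already stopped have their sampled id masked to -1 and emit token 0, while running slots keep their sampled id and are marked stopped once that id appears in `end_ids`.

```python
import numpy as np

def emulate_stop_update(topk_ids, stop_flags, end_ids):
    # Illustrative emulation only, not the op's exact in-place implementation.
    topk_ids = topk_ids.copy()
    stop_flags = stop_flags.copy()
    next_tokens = np.zeros_like(topk_ids)
    for i in range(len(topk_ids)):
        if stop_flags[i]:
            topk_ids[i] = -1      # already finished: mask the sampled id
            next_tokens[i] = 0    # and emit nothing
        else:
            next_tokens[i] = topk_ids[i]
            if topk_ids[i] in end_ids:
                stop_flags[i] = True  # hit an end token: mark the slot stopped
    return topk_ids, next_tokens, stop_flags

ids, nxt, flags = emulate_stop_update(
    np.arange(8, dtype="int64"),
    np.array([False, True] * 4),
    np.array([0, 1, 2, 3, 4, 5], dtype="int64"),
)
```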

View File

@@ -60,9 +60,17 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens) print("next_tokens:\n", next_tokens)
print("is_block_step:\n", is_block_step) print("is_block_step:\n", is_block_step)
update_inputs(stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder, update_inputs(
seq_lens_decoder, input_ids, stop_nums, next_tokens, stop_flags,
is_block_step) not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
next_tokens,
is_block_step,
)
print("-" * 50) print("-" * 50)
print("stop_flags:\n", stop_flags) print("stop_flags:\n", stop_flags)
@@ -75,32 +83,269 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens) print("next_tokens:\n", next_tokens)
ref_not_need_stop_out = np.array([True]) ref_not_need_stop_out = np.array([True])
ref_seq_lens_this_time_out = np.array([ ref_seq_lens_this_time_out = np.array(
0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, [
0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1 0,
], "int32") 0,
ref_seq_lens_encoder_out = np.array([ 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 1,
], "int32") 0,
ref_seq_lens_decoder_out = np.array([ 1,
0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0, 1,
24, 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0, 1,
46, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 0,
], "int32") 1,
input_ids_np[:, 0] = np.array([ 1,
6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, 5, 0,
9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, 6, 7, 0,
4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8 0,
], "int64") 0,
0,
0,
0,
1,
1,
0,
1,
1,
0,
1,
1,
0,
0,
0,
1,
1,
0,
1,
0,
0,
1,
0,
1,
0,
0,
1,
0,
0,
1,
1,
1,
],
"int32",
)
ref_seq_lens_encoder_out = np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
ref_seq_lens_decoder_out = np.array(
[
0,
0,
2,
0,
0,
6,
0,
8,
8,
10,
0,
12,
12,
0,
0,
0,
0,
0,
0,
0,
20,
22,
0,
24,
24,
0,
26,
28,
0,
0,
0,
32,
32,
0,
34,
0,
0,
38,
0,
40,
0,
0,
42,
0,
0,
46,
46,
48,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
input_ids_np[:, 0] = np.array(
[
6,
5,
9,
8,
6,
2,
8,
1,
3,
1,
3,
6,
9,
8,
1,
9,
1,
8,
8,
6,
7,
6,
5,
3,
5,
9,
3,
6,
3,
9,
8,
8,
8,
8,
4,
8,
7,
4,
2,
3,
5,
8,
4,
2,
5,
6,
8,
9,
6,
7,
4,
2,
4,
6,
2,
3,
4,
9,
7,
2,
1,
8,
7,
8,
],
"int64",
)
assert not_need_stop.numpy(
) == ref_not_need_stop_out, 'Check not_need_stop failed.'
assert np.all(seq_lens_this_time.numpy() ==
              ref_seq_lens_this_time_out), 'Check seq_lens_this_time failed.'
assert np.all(seq_lens_encoder.numpy() ==
              ref_seq_lens_encoder_out), 'Check seq_lens_encoder failed.'
assert np.all(seq_lens_decoder.numpy() ==
              ref_seq_lens_decoder_out), 'Check seq_lens_decoder failed.'
assert np.all(input_ids.numpy() == input_ids_np), 'Check input_ids failed.'
assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."
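
Reading the expected outputs above, the op appears to zero per-step lengths for stopped slots, accumulate decoder lengths for running slots, and write the sampled token back as the next step's input. A heavily simplified NumPy sketch of that assumed bookkeeping, for illustration only:

```python
import numpy as np

def emulate_update_inputs(stop_flags, seq_lens_this_time, seq_lens_decoder, input_ids, next_tokens):
    """Assumed semantics, for illustration only; the real custom op updates its tensors in place."""
    running = ~stop_flags
    # Running slots accumulate the tokens generated this step; stopped slots reset to 0.
    new_decoder = np.where(running, seq_lens_decoder + seq_lens_this_time, 0).astype("int32")
    # Next step each running slot decodes exactly one token.
    new_this_time = np.where(running, 1, 0).astype("int32")
    new_encoder = np.zeros_like(new_this_time)  # prompts are already consumed
    new_input_ids = input_ids.copy()
    new_input_ids[:, 0] = next_tokens  # the sampled token becomes the next step's input
    # The real op also consults stop_nums; here we only report whether anything is still running.
    not_need_stop = bool(running.any())
    return not_need_stop, new_this_time, new_encoder, new_decoder, new_input_ids
```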

View File

@@ -29,16 +29,15 @@ def np_quant_weight_int4(weight_np):
weight = np.transpose(weight_np, [1, 0]) # n,k weight = np.transpose(weight_np, [1, 0]) # n,k
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1) # k => k,1 max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1) # k => k,1
quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7) # n,k quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7) # n,k
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | ( quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1) weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
return quanted_weight, weight_scales.astype(np.float32) return quanted_weight, weight_scales.astype(np.float32)
def np_quant_weight(weight_np, algo='weight_only_int8'): def np_quant_weight(weight_np, algo="weight_only_int8"):
assert weight_np.dtype == np.float32 assert weight_np.dtype == np.float32
if algo == 'weight_only_int4': if algo == "weight_only_int4":
return np_quant_weight_int4(weight_np) return np_quant_weight_int4(weight_np)
weight = np.transpose(weight_np, [1, 0]) weight = np.transpose(weight_np, [1, 0])
@@ -56,7 +55,7 @@ def int8_to_bin_np(value):
def int8_to_bin(value): def int8_to_bin(value):
if not -128 <= value <= 127: if not -128 <= value <= 127:
raise ValueError("int8 值必须在 -128 到 127 之间") raise ValueError("int8 值必须在 -128 到 127 之间")
return format(value & 0xFF, '08b') # '08b' 表示 8 位二进制,高位补零 return format(value & 0xFF, "08b") # '08b' 表示 8 位二进制,高位补零
# 1) preparation # 1) preparation
@@ -70,7 +69,7 @@ w_np = (np.random.random((k, n)).astype(np.float32) - 0.5) * 10
qw_np, wscale_np = np_quant_weight(w_np, algo) qw_np, wscale_np = np_quant_weight(w_np, algo)
# 3) xpu calculation # 3) xpu calculation
dtype = 'float32' dtype = "float32"
x_pd = paddle.to_tensor(w_np, dtype=dtype) x_pd = paddle.to_tensor(w_np, dtype=dtype)
qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1) qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1)
qw_pd_trans = paddle.transpose(qw_pd, [1, 0]) qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
@@ -83,12 +82,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
# comparation # comparation
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}") print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}") print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
print( print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}" print(f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}")
) sum_diff = np.sum(np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(
f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}"
)
sum_diff = np.sum(
np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"sum_diff: {sum_diff}") print(f"sum_diff: {sum_diff}")
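
The int4 path above packs two signed 4-bit values into one byte: odd columns go to the high nibble and even columns to the low nibble. A small round-trip sketch of that packing, independent of the XPU op:

```python
import numpy as np

# Toy int4 values already clipped to [-7, 7]; shape [n, k] with even k.
q = np.array([[-7, 3, 0, 5], [2, -1, 7, -6]], dtype=np.int32)

# Pack two int4 values per byte: odd columns -> high nibble, even columns -> low nibble
# (same layout as the numpy reference above).
packed = (((q[:, 1::2] & 0xF) << 4) | (q[:, ::2] & 0xF)).astype(np.uint8)

# Unpack: extract the nibbles, then sign-extend each 4-bit value back to a signed int.
low = (packed & 0xF).astype(np.int32)
high = ((packed >> 4) & 0xF).astype(np.int32)
low = np.where(low > 7, low - 16, low)
high = np.where(high > 7, high - 16, high)

restored = np.empty_like(q)
restored[:, ::2] = low
restored[:, 1::2] = high
assert np.array_equal(restored, q), "int4 pack/unpack round trip failed"
```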

View File

@@ -72,7 +72,8 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
### Multi-machine Disaggregated Deployment ### Multi-machine Disaggregated Deployment
#### Prerequisite: Redis #### Prerequisite: Redis
- Installation via `conda` * Installation via `conda`
```bash ```bash
# Install # Install
conda install redis conda install redis
@@ -80,7 +81,8 @@ conda install redis
nohup redis-server > redis.log 2>&1 & nohup redis-server > redis.log 2>&1 &
``` ```
- Installation via `apt` * Installation via `apt`
```bash ```bash
# Install # Install
sudo apt install redis-server -y sudo apt install redis-server -y
@@ -88,7 +90,8 @@ sudo apt install redis-server -y
sudo systemctl start redis-server sudo systemctl start redis-server
``` ```
- Installation via `yum` * Installation via `yum`
```bash ```bash
# Install # Install
sudo yum install redis -y sudo yum install redis -y

View File

@@ -38,6 +38,7 @@ conda install redis
# Launch # Launch
nohup redis-server > redis.log 2>&1 & nohup redis-server > redis.log 2>&1 &
``` ```
### apt installation (Debian/Ubuntu) ### apt installation (Debian/Ubuntu)
```bash ```bash
@@ -57,6 +58,7 @@ sudo systemctl start redis
``` ```
## Launching FastDeploy ## Launching FastDeploy
```bash ```bash
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--port 8801 \ --port 8801 \
@@ -72,6 +74,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-min-load_score 3 \ --scheduler-min-load_score 3 \
--scheduler-load-shards-num 1 --scheduler-load-shards-num 1
``` ```
[Scheduler Launching Parameter](../online_serving/scheduler.md) [Scheduler Launching Parameter](../online_serving/scheduler.md)
### Deployment notes: ### Deployment notes:

View File

@@ -20,6 +20,7 @@ For reasoning models, the length of the reasoning content can be controlled via
### Quick Start ### Quick Start
When launching the model service, specify the parser name using the `--reasoning-parser` argument. When launching the model service, specify the parser name using the `--reasoning-parser` argument.
This parser will process the model's output and extract the `reasoning_content` field. This parser will process the model's output and extract the `reasoning_content` field.
```bash ```bash
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \ --model /path/to/your/model \
@@ -29,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \ --quantization wint4 \
--reasoning-parser ernie-45-vl --reasoning-parser ernie-45-vl
``` ```
Next, make a request to the model that should return the reasoning content in the response. Next, make a request to the model that should return the reasoning content in the response.
```bash ```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
@@ -43,10 +46,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
"metadata": {"enable_thinking": true} "metadata": {"enable_thinking": true}
}' }'
``` ```
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself. The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
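
The same request can be made from Python with the OpenAI client. A minimal sketch mirroring the curl example above; the model name is a placeholder, the API key value is arbitrary for a local deployment, and `reasoning_content` is read defensively since the exact response shape depends on your deployment:

```python
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8192/v1", api_key="EMPTY")  # key value is a placeholder
response = client.chat.completions.create(
    model="ernie-45-vl",  # placeholder name; adjust to whatever your deployment expects
    messages=[{"role": "user", "content": "Which is larger, 9.11 or 9.9?"}],
    extra_body={"metadata": {"enable_thinking": True}},  # mirrors the curl request body above
)
message = response.choices[0].message
print("reasoning:", getattr(message, "reasoning_content", None))
print("answer:", message.content)
```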
### Streaming chat completions ### Streaming chat completions
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks` Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks`
```python ```python
from openai import OpenAI from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server. # Set OpenAI's API key and API base to use vLLM's API server.

View File

@@ -132,4 +132,3 @@ Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json ```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}} {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
``` ```
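
A small helper sketch for reading those `result.jsonl` records back in Python (the path is whatever you passed to the benchmark script):

```python
import json

with open("result.jsonl") as f:
    for line in f:
        record = json.loads(line)
        print(f'{record["task"]}: accuracy={record["accuracy"]}, '
              f'latency={record["latency"]} over {record["num_requests"]} requests')
```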

View File

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
``` ```
## 2. Start service ## 2. Start service
```bash ```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN" export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization=0.8 --gpu-memory-utilization=0.8
``` ```
#### Send requests ### Send requests
Send requests using either curl or Python Send requests using either curl or Python

View File

@@ -19,13 +19,15 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
## Container Preparation ## Container Preparation
1. Start Container 1. Start Container
```bash ```bash
docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
docker exec -it paddle_infer bash docker exec -it paddle_infer bash
``` ```
/home/paddle contains the model files, *.whl packages, and scripts. /home/paddle contains the model files, *.whl packages, and scripts.
2. Install packages 1. Install packages
```bash ```bash
pip3 install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/ pip3 install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
@@ -38,6 +40,7 @@ pip3 install fastdeploy_iluvatar_gpu -i https://www.paddlepaddle.org.cn/packages
script list below: script list below:
`run_demo.sh`: `run_demo.sh`:
```bash ```bash
#!/bin/bash #!/bin/bash
export PADDLE_XCCL_BACKEND=iluvatar_gpu export PADDLE_XCCL_BACKEND=iluvatar_gpu
@@ -78,7 +81,9 @@ for output in outputs:
```bash ```bash
./run_demo.sh ./run_demo.sh
``` ```
The following logs will be printed. Loading the model takes approximately 74 seconds, and running the demo takes approximately 240 seconds. The following logs will be printed. Loading the model takes approximately 74 seconds, and running the demo takes approximately 240 seconds.
``` ```
/usr/local/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md /usr/local/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
warnings.warn(warning_message) warnings.warn(warning_message)

View File

@@ -36,6 +36,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
] ]
}' }'
``` ```
Here's an example curl command demonstrating how to include the logprobs parameter in a user request: Here's an example curl command demonstrating how to include the logprobs parameter in a user request:
```bash ```bash
@@ -49,6 +50,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
``` ```
Here is an example of sending a user request using a Python script: Here is an example of sending a user request using a Python script:
```python ```python
import openai import openai
host = "0.0.0.0" host = "0.0.0.0"

View File

@@ -45,7 +45,6 @@ When using FastDeploy to deploy models (including offline inference and service
| ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel | | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
| ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probabilities of each output token are returned in the message content. If logprob is not used, this parameter can be omitted at startup | | ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probabilities of each output token are returned in the message content. If logprob is not used, this parameter can be omitted at startup |
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```? ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
During FastDeploy inference, GPU memory is occupied by ```model weights```, ```preallocated KVCache blocks``` and ```model computation intermediate activation values```. The preallocated KVCache blocks are determined by ```num_gpu_blocks_override```, with ```block_size``` (default: 64) as its unit, meaning one block can store KVCache for 64 Tokens. During FastDeploy inference, GPU memory is occupied by ```model weights```, ```preallocated KVCache blocks``` and ```model computation intermediate activation values```. The preallocated KVCache blocks are determined by ```num_gpu_blocks_override```, with ```block_size``` (default: 64) as its unit, meaning one block can store KVCache for 64 Tokens.
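
To make the relationship concrete, a back-of-the-envelope helper with purely illustrative numbers:

```python
# Each preallocated block holds `block_size` tokens of KVCache, so the total token
# capacity is simply num_gpu_blocks_override * block_size.
def kvcache_token_capacity(num_gpu_blocks_override: int, block_size: int = 64) -> int:
    return num_gpu_blocks_override * block_size

# e.g. 2048 blocks at the default block_size of 64 tokens per block
print(kvcache_token_capacity(2048))  # 131072 tokens of KVCache capacity
```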
@@ -80,18 +79,18 @@ Currently, only user configuration of the following parameters is supported
CudaGraph can be enabled by setting `--use-cudagraph` or `--graph-optimization-config '{"use_cudagraph":true}'`. Setting it through both options at the same time may cause conflicts. CudaGraph can be enabled by setting `--use-cudagraph` or `--graph-optimization-config '{"use_cudagraph":true}'`. Setting it through both options at the same time may cause conflicts.
The `graph_opt_level` parameter within `--graph-optimization-config` is used to configure the graph optimization level, with the following available options: The `graph_opt_level` parameter within `--graph-optimization-config` is used to configure the graph optimization level, with the following available options:
- `0`: Use the dynamic compute graph (default) - `0`: Use the dynamic compute graph (default)
- `1`: Use the static compute graph; during the initialization phase, the Paddle API is used to convert the dynamic graph into a static graph - `1`: Use the static compute graph; during the initialization phase, the Paddle API is used to convert the dynamic graph into a static graph
- `2`: Based on the static compute graph, use Paddle's compiler (CINN, Compiler Infrastructure for Neural Networks) to compile and optimize - `2`: Based on the static compute graph, use Paddle's compiler (CINN, Compiler Infrastructure for Neural Networks) to compile and optimize
In general, static graphs have lower Kernel Launch overhead than dynamic graphs, and it is recommended to use static graphs. In general, static graphs have lower Kernel Launch overhead than dynamic graphs, and it is recommended to use static graphs.
For adapted models, FastDeploy's CudaGraph * * can support both dynamic and static graphs * * simultaneously. For adapted models, FastDeploy's CudaGraph *can support both dynamic and static graphs* simultaneously.
When CudaGraph is enabled with the default configuration, the list of Batch Sizes that CudaGraph needs to capture is set automatically based on the `max_num_seqs` parameter. The capture list is generated as follows: When CudaGraph is enabled with the default configuration, the list of Batch Sizes that CudaGraph needs to capture is set automatically based on the `max_num_seqs` parameter. The capture list is generated as follows:
1. Generate a candidate list of Batch Sizes in the range [1, 1024]. 1. Generate a candidate list of Batch Sizes in the range [1, 1024].
``` ```
# Batch Size [1, 2, 4, 8, 16, ... 120, 128] # Batch Size [1, 2, 4, 8, 16, ... 120, 128]
candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)] candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
@@ -100,24 +99,25 @@ When CudaGraph is enabled in the default configuration, a list of Batch Sizes th
# Batch Size (256, 288, ... 992, 1024] # Batch Size (256, 288, ... 992, 1024]
candidate_capture_sizes += [32 * i for i in range(17, 33)] candidate_capture_sizes += [32 * i for i in range(17, 33)]
``` ```
2. Crop the candidate list based on the user-set `max_num_seqs` to obtain the CudaGraph capture list in the range [1, `max_num_seqs`]. 2. Crop the candidate list based on the user-set `max_num_seqs` to obtain the CudaGraph capture list in the range [1, `max_num_seqs`].
Users can also customize the Batch Size list that CudaGraph should capture through the `cudagraph_capture_sizes` parameter in `--graph-optimization-config`: Users can also customize the Batch Size list that CudaGraph should capture through the `cudagraph_capture_sizes` parameter in `--graph-optimization-config`:
``` ```
--graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}' --graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}'
``` ```
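
Putting the two steps together, a short sketch of how such a capture list can be derived. The cropping rule against `max_num_seqs` follows the description above; the helper itself is illustrative, not FastDeploy's internal function:

```python
def default_capture_sizes(max_num_seqs: int) -> list[int]:
    # Batch Size [1, 2, 4, 8, 16, ... 120, 128]
    candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
    # Batch Size (256, 288, ... 992, 1024] (further ranges are appended the same way)
    candidate_capture_sizes += [32 * i for i in range(17, 33)]
    # Crop the candidate list to the range [1, max_num_seqs]
    return [bs for bs in candidate_capture_sizes if bs <= max_num_seqs]

print(default_capture_sizes(48))  # -> [1, 2, 4, 8, 16, 24, 32, 40, 48]
```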
### CudaGraph related parameters ### CudaGraph related parameters
Using CudaGraph incurs some additional memory overhead, divided into two categories in FastDeploy: Using CudaGraph incurs some additional memory overhead, divided into two categories in FastDeploy:
* Additional input Buffer overhead - Additional input Buffer overhead
* CudaGraph uses dedicated memory pool, thus holding some intermediate activation memory isolated from main framework - CudaGraph uses dedicated memory pool, thus holding some intermediate activation memory isolated from main framework
During initialization, FastDeploy first uses the `gpu_memory_utilization` parameter to compute the memory available for `KVCache`; only after `KVCache` is initialized is the remaining memory used to initialize CudaGraph. Since CudaGraph is not enabled by default yet, default startup parameters may run into `Out of memory` errors; the following solutions can be tried: During initialization, FastDeploy first uses the `gpu_memory_utilization` parameter to compute the memory available for `KVCache`; only after `KVCache` is initialized is the remaining memory used to initialize CudaGraph. Since CudaGraph is not enabled by default yet, default startup parameters may run into `Out of memory` errors; the following solutions can be tried:
* Lower `gpu_memory_utilization` value, reserve more memory for CudaGraph. - Lower `gpu_memory_utilization` value, reserve more memory for CudaGraph.
* Lower `max_num_seqs` to decrease the maximum concurrency. - Lower `max_num_seqs` to decrease the maximum concurrency.
* Customize the batch size list that CudaGraph needs to capture through `graph_optimization_config`, and reduce the number of captured graphs by using `cudagraph_capture_sizes` - Customize the batch size list that CudaGraph needs to capture through `graph_optimization_config`, and reduce the number of captured graphs by using `cudagraph_capture_sizes`
- Before use, make sure the loaded model is properly decorated with ```@support_graph_optimization```. - Before use, make sure the loaded model is properly decorated with ```@support_graph_optimization```.
@@ -148,5 +148,6 @@ FastDeploy initialization sequence first uses `gpu_memory_utilization` parameter
class Ernie45TModel(nn.Layer): # Note decorator is added to nn.Layer subclass class Ernie45TModel(nn.Layer): # Note decorator is added to nn.Layer subclass
... ...
``` ```
- When ```use_cudagraph``` is enabled, only single-GPU inference is currently supported, i.e. ```tensor_parallel_size``` set to 1. - When ```use_cudagraph``` is enabled, only single-GPU inference is currently supported, i.e. ```tensor_parallel_size``` set to 1.
- When ```use_cudagraph``` is enabled, ```enable_prefix_caching``` and ```enable_chunked_prefill``` cannot be enabled. - When ```use_cudagraph``` is enabled, ```enable_prefix_caching``` and ```enable_chunked_prefill``` cannot be enabled.

View File

@@ -25,13 +25,10 @@
多实例情况下每收到一条请求需要根据不同的策略将请求分配到不同的Prefill实例和Decode实例。通过角色分离prefill 节点负责接收并处理请求decode节点完成后续生成可以更细粒度地控制资源分配、提高吞吐量与 GPU 利用率。 多实例情况下每收到一条请求需要根据不同的策略将请求分配到不同的Prefill实例和Decode实例。通过角色分离prefill 节点负责接收并处理请求decode节点完成后续生成可以更细粒度地控制资源分配、提高吞吐量与 GPU 利用率。
## 使用说明 ## 使用说明
### 单机分离式部署 ### 单机分离式部署
#### 在线推理服务 #### 在线推理服务
使用如下命令进行服务部署 使用如下命令进行服务部署
@@ -75,9 +72,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
### 多机分离式部署 ### 多机分离式部署
#### 前置依赖 Redis #### 前置依赖 Redis
- 使用`conda`安装 * 使用`conda`安装
```bash ```bash
# 安装 # 安装
conda install redis conda install redis
@@ -85,7 +82,8 @@ conda install redis
nohup redis-server > redis.log 2>&1 & nohup redis-server > redis.log 2>&1 &
``` ```
- 使用`apt`安装 * 使用`apt`安装
```bash ```bash
# 安装 # 安装
sudo apt install redis-server -y sudo apt install redis-server -y
@@ -93,7 +91,8 @@ sudo apt install redis-server -y
sudo systemctl start redis-server sudo systemctl start redis-server
``` ```
- 使用`yum`安装 * 使用`yum`安装
```bash ```bash
# 安装 # 安装
sudo yum install redis -y sudo yum install redis -y

View File

@@ -23,6 +23,7 @@
### 前置依赖 Redis ### 前置依赖 Redis
- 使用`conda`安装 - 使用`conda`安装
```bash ```bash
# 安装 # 安装
conda install redis conda install redis
@@ -31,6 +32,7 @@ nohup redis-server > redis.log 2>&1 &
``` ```
- 使用`apt`安装 - 使用`apt`安装
```bash ```bash
# 安装 # 安装
sudo apt install redis-server -y sudo apt install redis-server -y
@@ -39,6 +41,7 @@ sudo systemctl start redis-server
``` ```
- 使用`yum`安装 - 使用`yum`安装
```bash ```bash
# 安装 # 安装
sudo yum install redis -y sudo yum install redis -y
@@ -47,6 +50,7 @@ sudo systemctl start redis
``` ```
### 启动FastDeploy ### 启动FastDeploy
```bash ```bash
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--port 8801 \ --port 8801 \
@@ -62,6 +66,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-min-load_score 3 \ --scheduler-min-load_score 3 \
--scheduler-load-shards-num 1 --scheduler-load-shards-num 1
``` ```
[启动参数说明](../online_serving/scheduler.md) [启动参数说明](../online_serving/scheduler.md)
可以将上述启动命令在多个机器执行,启动多个推理实例(如果是在一个机器中启动多个推理实例,注意端口不要冲突)。 可以将上述启动命令在多个机器执行,启动多个推理实例(如果是在一个机器中启动多个推理实例,注意端口不要冲突)。

View File

@@ -8,7 +8,6 @@ Prefix Caching前缀缓存是一种优化生成式模型推理效率的技
增量计算:对于后续请求,只需计算新增部分(如用户追加的输入)并复用缓存的中间结果,显著减少计算量。 增量计算:对于后续请求,只需计算新增部分(如用户追加的输入)并复用缓存的中间结果,显著减少计算量。
## 服务化部署开启 Prefix Caching ## 服务化部署开启 Prefix Caching
启动服务增加以下参数 `enable-prefix-caching`默认只开启一级缓存GPU 缓存)。 启动服务增加以下参数 `enable-prefix-caching`默认只开启一级缓存GPU 缓存)。

View File

@@ -17,10 +17,10 @@
同时在思考模型中,支持通过```reasoning_max_tokens```控制思考内容的长度,在请求中添加```metadata={"reasoning_max_tokens": 1024}```即可。 同时在思考模型中,支持通过```reasoning_max_tokens```控制思考内容的长度,在请求中添加```metadata={"reasoning_max_tokens": 1024}```即可。
## 快速使用
### 快速使用
在启动模型服务时, 通过`--reasoning-parser`参数指定解析器名称. 在启动模型服务时, 通过`--reasoning-parser`参数指定解析器名称.
该解析器会解析思考模型的输出, 提取`reasoning_content`字段. 该解析器会解析思考模型的输出, 提取`reasoning_content`字段.
```bash ```bash
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \ --model /path/to/your/model \
@@ -30,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \ --quantization wint4 \
--reasoning-parser ernie-45-vl --reasoning-parser ernie-45-vl
``` ```
接下来, 向模型发送 `chat completion` 请求 接下来, 向模型发送 `chat completion` 请求
```bash ```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
@@ -45,10 +47,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
}' }'
``` ```
字段`reasoning_content`包含得出最终结论的思考步骤,而`content`字段包含最终结论。 字段`reasoning_content`包含得出最终结论的思考步骤,而`content`字段包含最终结论。
### 流式会话 ### 流式会话
在流式会话中, `reasoning_content`字段会可以在`chat completion response chunks`中的 `delta` 中获取 在流式会话中, `reasoning_content`字段会可以在`chat completion response chunks`中的 `delta` 中获取
```python ```python
from openai import OpenAI from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server. # Set OpenAI's API key and API base to use vLLM's API server.
@@ -73,4 +77,3 @@ for chunk in chat_response:
print("\n") print("\n")
``` ```

View File

@@ -50,6 +50,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}'
``` ```
### PD 分离式部署1P1D ### PD 分离式部署1P1D
> 在8×H100上部署1P1DP、D节点 分别使用 4×H100量化方式选择 WINT4 > 在8×H100上部署1P1DP、D节点 分别使用 4×H100量化方式选择 WINT4
> 与常规 PD 分离部署一致,仅需替换配置文件并新增 speculative_config > 与常规 PD 分离部署一致,仅需替换配置文件并新增 speculative_config
@@ -57,6 +58,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
- P 节点Prefill - P 节点Prefill
> 配置文件: `benchmarks/yaml/eb45t-32k-wint4-mtp-tp4-prefill.yaml` > 配置文件: `benchmarks/yaml/eb45t-32k-wint4-mtp-tp4-prefill.yaml`
``` ```
export FD_LOG_DIR="log_prefill" export FD_LOG_DIR="log_prefill"
rm -rf ${FD_LOG_DIR} rm -rf ${FD_LOG_DIR}
@@ -80,9 +82,11 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-password "scheduler_mtp" \ --scheduler-password "scheduler_mtp" \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": ""${path_to_mtp_model}"}' & --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": ""${path_to_mtp_model}"}' &
``` ```
- D 节点Decode - D 节点Decode
> 配置文件: `benchmarks/yaml/eb45t-32k-wint4-mtp-tp4-decode.yaml` > 配置文件: `benchmarks/yaml/eb45t-32k-wint4-mtp-tp4-decode.yaml`
``` ```
export FD_LOG_DIR="log_prefill" export FD_LOG_DIR="log_prefill"
rm -rf ${FD_LOG_DIR} rm -rf ${FD_LOG_DIR}
@@ -111,6 +115,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
该算法通过 n-gram 窗口从 prompt 和已生成的 Token 中进行匹配生成草稿 Token适合输入和输出有很大 overlap 的场景,如代码续写、文档查询等。 该算法通过 n-gram 窗口从 prompt 和已生成的 Token 中进行匹配生成草稿 Token适合输入和输出有很大 overlap 的场景,如代码续写、文档查询等。
> 使用 4×H100量化方式选择 WINT4 > 使用 4×H100量化方式选择 WINT4
> 配置文件benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml > 配置文件benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
``` ```
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model ${path_to_main_model} \ --model ${path_to_main_model} \

View File

@@ -131,4 +131,3 @@ python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parall
```json ```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}} {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
``` ```

View File

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
``` ```
## 2. 启动服务 ## 2. 启动服务
```bash ```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN" export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization=0.8 --gpu-memory-utilization=0.8
``` ```
#### 请求服务 ### 请求服务
您可以基于 OpenAI 协议,通过 curl 和 python 两种方式请求服务。 您可以基于 OpenAI 协议,通过 curl 和 python 两种方式请求服务。

View File

@@ -18,13 +18,15 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
## 准备容器 ## 准备容器
1. 启动容器 1. 启动容器
```bash ```bash
docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
docker exec -it paddle_infer bash docker exec -it paddle_infer bash
``` ```
/home/paddle 为模型文件、whl包、脚本所在目录 /home/paddle 为模型文件、whl包、脚本所在目录
2. 安装whl包 1. 安装whl包
```bash ```bash
pip3 install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/ pip3 install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
@@ -37,6 +39,7 @@ pip3 install fastdeploy_iluvatar_gpu -i https://www.paddlepaddle.org.cn/packages
脚本内容如下 脚本内容如下
`run_demo.sh`: `run_demo.sh`:
```bash ```bash
#!/bin/bash #!/bin/bash
export PADDLE_XCCL_BACKEND=iluvatar_gpu export PADDLE_XCCL_BACKEND=iluvatar_gpu
@@ -48,7 +51,6 @@ python3 run_demo.py
run_demo.py run_demo.py
```python ```python
from fastdeploy import LLM, SamplingParams from fastdeploy import LLM, SamplingParams
@@ -75,10 +77,13 @@ for output in outputs:
## 运行demo ## 运行demo
执行 执行
```bash ```bash
./run_demo.sh ./run_demo.sh
``` ```
会有如下 log 打印load 模型耗时约74sdemo 运行约240s。 会有如下 log 打印load 模型耗时约74sdemo 运行约240s。
``` ```
/usr/local/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md /usr/local/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
warnings.warn(warning_message) warnings.warn(warning_message)

View File

@@ -21,6 +21,7 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12
## 2. 预编译Pip安装 ## 2. 预编译Pip安装
首先安装 paddlepaddle-gpu详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) 首先安装 paddlepaddle-gpu详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
``` shell ``` shell
python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
``` ```
@@ -28,6 +29,7 @@ python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn
再安装 fastdeploy**注意不要通过pypi源安装**,需要通过如下方式安装 再安装 fastdeploy**注意不要通过pypi源安装**,需要通过如下方式安装
如你的 GPU 是 SM80/90 架构(A100/H100等),按如下方式安装 如你的 GPU 是 SM80/90 架构(A100/H100等),按如下方式安装
``` ```
# 安装稳定版本fastdeploy # 安装稳定版本fastdeploy
python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
@@ -37,6 +39,7 @@ python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages
``` ```
如你的 GPU 是 SM86/89 架构(4090/L20/L40等),按如下方式安装 如你的 GPU 是 SM86/89 架构(4090/L20/L40等),按如下方式安装
``` ```
# 安装稳定版本fastdeploy # 安装稳定版本fastdeploy
python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
@@ -59,11 +62,13 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu .
## 4. Wheel包源码编译 ## 4. Wheel包源码编译
首先安装 paddlepaddle-gpu详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/) 首先安装 paddlepaddle-gpu详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/)
``` shell ``` shell
python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
``` ```
接着克隆源代码,编译安装 接着克隆源代码,编译安装
``` shell ``` shell
git clone https://github.com/PaddlePaddle/FastDeploy git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy cd FastDeploy
@@ -74,11 +79,13 @@ cd FastDeploy
# 第4个参数: 编译的GPU架构 # 第4个参数: 编译的GPU架构
bash build.sh 1 python false [80,90] bash build.sh 1 python false [80,90]
``` ```
编译后的产物在```FastDeploy/dist```目录下。 编译后的产物在```FastDeploy/dist```目录下。
## 环境检查 ## 环境检查
在安装 FastDeploy 后,通过如下 Python 代码检查环境的可用性 在安装 FastDeploy 后,通过如下 Python 代码检查环境的可用性
``` python ``` python
import paddle import paddle
from paddle.jit.marker import unified from paddle.jit.marker import unified
@@ -87,4 +94,5 @@ paddle.utils.run_check()
# 检查FastDeploy自定义算子编译成功与否 # 检查FastDeploy自定义算子编译成功与否
from fastdeploy.model_executor.ops.gpu import beam_search_softmax from fastdeploy.model_executor.ops.gpu import beam_search_softmax
``` ```
如上代码执行成功,则认为环境可用。 如上代码执行成功,则认为环境可用。

View File

@@ -15,6 +15,7 @@
## 1. 启动服务 ## 1. 启动服务
安装FastDeploy后在终端执行如下命令启动服务其中启动命令配置方式参考[参数说明](../parameters.md) 安装FastDeploy后在终端执行如下命令启动服务其中启动命令配置方式参考[参数说明](../parameters.md)
```shell ```shell
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-0.3B-Paddle \ --model baidu/ERNIE-4.5-0.3B-Paddle \
@@ -24,6 +25,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--max-model-len 32768 \ --max-model-len 32768 \
--max-num-seqs 32 --max-num-seqs 32
``` ```
>💡 注意:在 ```--model``` 指定的路径中,若当前目录下不存在该路径对应的子目录,则会尝试根据指定的模型名称(如 ```baidu/ERNIE-4.5-0.3B-Paddle```查询AIStudio是否存在预置模型若存在则自动启动下载。默认的下载路径为```~/xx```。关于模型自动下载的说明和配置参阅[模型下载](../supported_models.md)。 >💡 注意:在 ```--model``` 指定的路径中,若当前目录下不存在该路径对应的子目录,则会尝试根据指定的模型名称(如 ```baidu/ERNIE-4.5-0.3B-Paddle```查询AIStudio是否存在预置模型若存在则自动启动下载。默认的下载路径为```~/xx```。关于模型自动下载的说明和配置参阅[模型下载](../supported_models.md)。
```--max-model-len``` 表示当前部署的服务所支持的最长Token数量。 ```--max-model-len``` 表示当前部署的服务所支持的最长Token数量。
```--max-num-seqs``` 表示当前部署的服务所支持的最大并发处理数量。 ```--max-num-seqs``` 表示当前部署的服务所支持的最大并发处理数量。
@@ -36,6 +38,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
## 2. 用户发起服务请求 ## 2. 用户发起服务请求
执行启动服务指令后,当终端打印如下信息,说明服务已经启动成功。 执行启动服务指令后,当终端打印如下信息,说明服务已经启动成功。
``` ```
api_server.py[line:91] Launching metrics service at http://0.0.0.0:8181/metrics api_server.py[line:91] Launching metrics service at http://0.0.0.0:8181/metrics
api_server.py[line:94] Launching chat completion service at http://0.0.0.0:8180/v1/chat/completions api_server.py[line:94] Launching chat completion service at http://0.0.0.0:8180/v1/chat/completions
@@ -47,11 +50,13 @@ INFO: Uvicorn running on http://0.0.0.0:8180 (Press CTRL+C to quit)
``` ```
FastDeploy提供服务探活接口用以判断服务的启动状态执行如下命令返回 ```HTTP/1.1 200 OK``` 即表示服务启动成功。 FastDeploy提供服务探活接口用以判断服务的启动状态执行如下命令返回 ```HTTP/1.1 200 OK``` 即表示服务启动成功。
```shell ```shell
curl -i http://0.0.0.0:8180/health curl -i http://0.0.0.0:8180/health
``` ```
通过如下命令发起服务请求 通过如下命令发起服务请求
```shell ```shell
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \ curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \

View File

@@ -24,6 +24,7 @@
## 文档说明 ## 文档说明
本项目文档基于mkdocs支持编译可视化查看参考如下命令进行编译预览 本项目文档基于mkdocs支持编译可视化查看参考如下命令进行编译预览
``` ```
pip install requirements.txt pip install requirements.txt
@@ -32,4 +33,5 @@ mkdocs build
mkdocs serve mkdocs serve
``` ```
根据提示打开相应地址即可。 根据提示打开相应地址即可。

View File

@@ -19,7 +19,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
--enable-logprob --enable-logprob
``` ```
服务部署时的命令行更多使用方式参考[参数说明](../parameters.md)。 服务部署时的命令行更多使用方式参考[参数说明](../parameters.md)。
## 发送用户请求 ## 发送用户请求
@@ -51,6 +50,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
``` ```
使用 Python 脚本发送用户请求示例如下: 使用 Python 脚本发送用户请求示例如下:
```python ```python
import openai import openai
host = "0.0.0.0" host = "0.0.0.0"
@@ -103,6 +103,7 @@ FastDeploy 增加的返回字段如下:
- `reasoning_content`: 思考链的返回结果 - `reasoning_content`: 思考链的返回结果
返回参数总览 返回参数总览
```python ```python
ChatCompletionStreamResponse: ChatCompletionStreamResponse:
id: str id: str

View File

@@ -18,7 +18,6 @@ FastDeploy 目前支持两种调度器: **本地调度器** 和 **全局调度
通过角色分离prefill 节点负责接收并处理请求decode节点完成后续生成可以更细粒度地控制资源分配、提高吞吐量与 GPU 利用率。 通过角色分离prefill 节点负责接收并处理请求decode节点完成后续生成可以更细粒度地控制资源分配、提高吞吐量与 GPU 利用率。
## 配置参数 ## 配置参数
| 字段名 | 字段类型 | 是否必填 | 默认值 | 生效范围 | 说明 | | 字段名 | 字段类型 | 是否必填 | 默认值 | 生效范围 | 说明 |
| ------------------------------------ | -------- | -------- | --------- |------------------------|-----------------------------------| | ------------------------------------ | -------- | -------- | --------- |------------------------|-----------------------------------|

View File

@@ -2,7 +2,6 @@
在使用FastDeploy部署模型包括离线推理、服务化部署涉及如下参数配置其实需要注意在使用离线推理时各参数配置即为如下参数名而在使用命令行启动服务时相应参数中的分隔符需要从```_```修改为```-```,如```max_model_len```在命令行中则为```--max-model-len```。 在使用FastDeploy部署模型包括离线推理、服务化部署涉及如下参数配置其实需要注意在使用离线推理时各参数配置即为如下参数名而在使用命令行启动服务时相应参数中的分隔符需要从```_```修改为```-```,如```max_model_len```在命令行中则为```--max-model-len```。
| 参数名 | 类型 | 说明 | | 参数名 | 类型 | 说明 |
|:-----------------------------------|:----------| :----- | |:-----------------------------------|:----------| :----- |
| ```port``` | `int` | 仅服务化部署需配置服务HTTP请求端口号默认8000 | | ```port``` | `int` | 仅服务化部署需配置服务HTTP请求端口号默认8000 |
@@ -44,7 +43,6 @@
| ```enable_expert_parallel``` | `bool` | 是否启用专家并行 | | ```enable_expert_parallel``` | `bool` | 是否启用专家并行 |
| ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob则在启动时可以省略此参数。 | | ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob则在启动时可以省略此参数。 |
## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系? ## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?
FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache块```和```模型计算中间激活值```占用。其中预分配KVCache块由```num_gpu_blocks_override```决定,其单位为```block_size```(默认64即一个块可以存储64个Token的KVCache。 FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache块```和```模型计算中间激活值```占用。其中预分配KVCache块由```num_gpu_blocks_override```决定,其单位为```block_size```(默认64即一个块可以存储64个Token的KVCache。
@@ -88,6 +86,7 @@ FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache
在默认配置下开启 CudaGraph 时,会根据 `max_num_seqs` 参数自动设置 CudaGraph 需要捕获的 Batch Size 列表,需要捕获的 Batch Size 的列表自动生成逻辑如下: 在默认配置下开启 CudaGraph 时,会根据 `max_num_seqs` 参数自动设置 CudaGraph 需要捕获的 Batch Size 列表,需要捕获的 Batch Size 的列表自动生成逻辑如下:
1. 生成一个范围为 [1,1024] Batch Size 的候选列表 1. 生成一个范围为 [1,1024] Batch Size 的候选列表
``` ```
# Batch Size [1, 2, 4, 8, 16, ... 120, 128] # Batch Size [1, 2, 4, 8, 16, ... 120, 128]
candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)] candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
@@ -96,24 +95,24 @@ FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache
# Batch Size (256, 288, ... 992, 1024] # Batch Size (256, 288, ... 992, 1024]
candidate_capture_sizes += [32 * i for i in range(17, 33)] candidate_capture_sizes += [32 * i for i in range(17, 33)]
``` ```
2. 根据用户设置的 `max_num_seqs` 裁剪候选列表,得到范围为 [1, `max_num_seqs`] 的 CudaGraph 捕获列表。 2. 根据用户设置的 `max_num_seqs` 裁剪候选列表,得到范围为 [1, `max_num_seqs`] 的 CudaGraph 捕获列表。
用户也可以通过 `--graph-optimization-config` 中的 `cudagraph_capture_sizes` 参数自定义需要被 CudaGraph 捕获的 Batch Size 列表: 用户也可以通过 `--graph-optimization-config` 中的 `cudagraph_capture_sizes` 参数自定义需要被 CudaGraph 捕获的 Batch Size 列表:
``` ```
--graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}' --graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}'
``` ```
### CudaGraph相关参数说明 ### CudaGraph相关参数说明
使用 CudaGraph 会产生一些额外的显存开销在FastDeploy中分为下面两类 使用 CudaGraph 会产生一些额外的显存开销在FastDeploy中分为下面两类
* 额外的输入 Buffer 开销 - 额外的输入 Buffer 开销
* CudaGraph 使用了专用的显存池,因此会持有一部分与主框架隔离的中间激活显存 - CudaGraph 使用了专用的显存池,因此会持有一部分与主框架隔离的中间激活显存
FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算 `KVCache` 可用的显存,初始化完 `KVCache` 之后才会使用剩余显存初始化 CudaGraph。由于 CudaGraph 目前还不是默认开启的,因此使用默认启动参数可能会遇到 `Out Of Memory` 错误,可以尝试使用下面三种方式解决: FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算 `KVCache` 可用的显存,初始化完 `KVCache` 之后才会使用剩余显存初始化 CudaGraph。由于 CudaGraph 目前还不是默认开启的,因此使用默认启动参数可能会遇到 `Out Of Memory` 错误,可以尝试使用下面三种方式解决:
* 调低 `gpu_memory_utilization` 的值多预留一些显存给CudaGraph使用。 - 调低 `gpu_memory_utilization` 的值多预留一些显存给CudaGraph使用。
* 调低 `max_num_seqs` 的值,降低最大并发数。 - 调低 `max_num_seqs` 的值,降低最大并发数。
* 通过 `graph_optimization_config` 自定义需要 CudaGraph 捕获的 Batch Size 列表 `cudagraph_capture_sizes`,减少捕获的图的数量 - 通过 `graph_optimization_config` 自定义需要 CudaGraph 捕获的 Batch Size 列表 `cudagraph_capture_sizes`,减少捕获的图的数量
使用CudaGraph之前需要确保加载的模型被装饰器 ```@support_graph_optimization```正确修饰。 使用CudaGraph之前需要确保加载的模型被装饰器 ```@support_graph_optimization```正确修饰。
@@ -144,5 +143,6 @@ FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算
class Ernie45TModel(nn.Layer): # 注意 decorator 加在 nn.Layer 的子类上 class Ernie45TModel(nn.Layer): # 注意 decorator 加在 nn.Layer 的子类上
... ...
``` ```
- 当开启 ```use_cudagraph``` 时,暂时只支持单卡推理,即 ```tensor_parallel_size``` 设为1。 - 当开启 ```use_cudagraph``` 时,暂时只支持单卡推理,即 ```tensor_parallel_size``` 设为1。
- 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunked_prefill``` 。 - 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunked_prefill``` 。

View File

@@ -44,4 +44,3 @@ FastDeploy 按以下格式命名各种量化精度:
- **WNF4A8C8**NF4指4bit norm-float数值类型 - **WNF4A8C8**NF4指4bit norm-float数值类型
- **Wfp8Afp8**权重和激活均为FP8精度 - **Wfp8Afp8**权重和激活均为FP8精度
- **W4Afp8**权重为INT4, 激活为FP8 - **W4Afp8**权重为INT4, 激活为FP8

View File

@@ -52,6 +52,3 @@ python -m fastdeploy.entrypoints.openai.api_server \
- 通过设置 `--quantization``block_wise_fp8` 选择在线 Block-wise FP8 量化。 - 通过设置 `--quantization``block_wise_fp8` 选择在线 Block-wise FP8 量化。
- 部署 ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 最少需要 80G * 8卡。 - 部署 ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 最少需要 80G * 8卡。
- 更多部署教程请参考[get_started](../get_started/ernie-4.5.md) - 更多部署教程请参考[get_started](../get_started/ernie-4.5.md)

View File

@@ -48,7 +48,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
- 更多部署教程请参考[get_started](../get_started/ernie-4.5.md) - 更多部署教程请参考[get_started](../get_started/ernie-4.5.md)
- 更多模型说明请参考[支持模型列表](../supported_models.md)。 - 更多模型说明请参考[支持模型列表](../supported_models.md)。
## WINT2效果 ## WINT2效果
在ERNIE-4.5-300B-A47B模型上WINT2与WINT4效果对比 在ERNIE-4.5-300B-A47B模型上WINT2与WINT4效果对比

View File

@@ -22,4 +22,3 @@
- ```splitwise```: 分离式部署相关模块 - ```splitwise```: 分离式部署相关模块
- ```scripts```/```tools```FastDeploy 用于执行功能的辅助脚本,比如编译,单测执行,代码风格纠正等 - ```scripts```/```tools```FastDeploy 用于执行功能的辅助脚本,比如编译,单测执行,代码风格纠正等
- ```test```:项目单测验证使用到的代码 - ```test```:项目单测验证使用到的代码

View File

@@ -19,14 +19,12 @@ FastDeploy 在部署过程中,会产生如下日志文件,各日志含义说
## 在线推理客户端日志 ## 在线推理客户端日志
* `api_server.log` : 记录启动参数,及接收到的请求信息 * `api_server.log` : 记录启动参数,及接收到的请求信息
## 调度器日志 ## 调度器日志
* `scheduler.log` : 记录调度器的信息包含当前结点的信息,每条请求分配的信息 * `scheduler.log` : 记录调度器的信息包含当前结点的信息,每条请求分配的信息
## 投机解码日志 ## 投机解码日志
* `speculate.log` : 投机解码相关信息 * `speculate.log` : 投机解码相关信息
## Prefix Caching 相关日志 ## Prefix Caching 相关日志
* `cache_queue_manager.log` : 记录启动参数,及接收到的请求信息 * `cache_queue_manager.log` : 记录启动参数,及接收到的请求信息

View File

@@ -22,14 +22,14 @@ import sys
os.environ["GLOG_minloglevel"] = "2" os.environ["GLOG_minloglevel"] = "2"
# suppress log from aistudio # suppress log from aistudio
os.environ["AISTUDIO_LOG"] = "critical" os.environ["AISTUDIO_LOG"] = "critical"
from fastdeploy.utils import version
from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM from fastdeploy.entrypoints.llm import LLM
__all__ = ['LLM', 'SamplingParams'] __all__ = ["LLM", "SamplingParams"]
try: try:
import use_triton_in_paddle import use_triton_in_paddle
use_triton_in_paddle.make_triton_compatible_with_paddle() use_triton_in_paddle.make_triton_compatible_with_paddle()
except ImportError: except ImportError:
pass pass
@@ -38,13 +38,21 @@ except ImportError:
def _patch_fastsafetensors(): def _patch_fastsafetensors():
try: try:
        file_path = subprocess.check_output([
            sys.executable, "-c", "import fastsafetensors, os; \
            print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
                'frameworks', '_paddle.py'))"
        ]).decode().strip()
        file_path = (
            subprocess.check_output(
                [
                    sys.executable,
                    "-c",
                    "import fastsafetensors, os; \
            print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
                'frameworks', '_paddle.py'))",
                ]
            )
            .decode()
            .strip()
        )
with open(file_path, 'r') as f: with open(file_path, "r") as f:
content = f.read() content = f.read()
if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content: if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content:
return return
@@ -56,21 +64,20 @@ def _patch_fastsafetensors():
inside_block = False inside_block = False
for line in lines: for line in lines:
new_lines.append(line) new_lines.append(line)
if 'need_workaround_dtypes: Dict[DType, DType] = {' in line: if "need_workaround_dtypes: Dict[DType, DType] = {" in line:
inside_block = True inside_block = True
elif inside_block and '}' in line: elif inside_block and "}" in line:
new_lines.insert(-1, ' DType.U16: DType.BF16,') new_lines.insert(-1, " DType.U16: DType.BF16,")
inside_block = False inside_block = False
modified = True modified = True
content = "\n".join(new_lines) content = "\n".join(new_lines)
if "DType.I8: paddle.uint8," in content: if "DType.I8: paddle.uint8," in content:
content = content.replace("DType.I8: paddle.uint8,", content = content.replace("DType.I8: paddle.uint8,", "DType.U8: paddle.uint8,")
"DType.U8: paddle.uint8,")
modified = True modified = True
if modified: if modified:
with open(file_path, 'w') as f: with open(file_path, "w") as f:
f.write(content + "\n") f.write(content + "\n")
except Exception as e: except Exception as e:

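A side note on the `_patch_fastsafetensors` helper above: the subprocess one-liner only serves to locate fastsafetensors' `frameworks/_paddle.py`. A sketch of an equivalent in-process lookup, assuming `fastsafetensors` is importable in the current interpreter (an alternative illustration, not what the patch does):

import importlib.util
import os

# resolve the installed package location without spawning a subprocess
spec = importlib.util.find_spec("fastsafetensors")
paddle_shim_path = os.path.join(os.path.dirname(spec.origin), "frameworks", "_paddle.py")
print(paddle_shim_path)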

@@ -109,13 +109,12 @@ class BlockNode:
parent_node_id = None parent_node_id = None
return ( return (
f"node_id {self.node_id}: depth {self.depth} hash_value {self.hash_value}" f"node_id {self.node_id}: depth {self.depth} hash_value {self.hash_value}"
+ + f" shared_count {self.shared_count} is_gpu_leaf_node {self.is_gpu_leaf_node}"
f" shared_count {self.shared_count} is_gpu_leaf_node {self.is_gpu_leaf_node}" + f" is_cpu_leaf_node {self.is_cpu_leaf_node} block_id {self.block_id} "
+ + f"has_in_gpu {self.has_in_gpu} "
f" is_cpu_leaf_node {self.is_cpu_leaf_node} block_id {self.block_id} " + f"cache_status {self.cache_status} parent {parent_node_id} with children number "
+ f"has_in_gpu {self.has_in_gpu} " + + f"{len(self.children)} req_id_set {self.req_id_set}"
f"cache_status {self.cache_status} parent {parent_node_id} with children number " )
+ f"{len(self.children)} req_id_set {self.req_id_set}")
@property @property
def has_in_gpu(self): def has_in_gpu(self):
@@ -141,8 +140,7 @@ class BlockNode:
""" """
check if the node is a leaf node in CPU check if the node is a leaf node in CPU
""" """
if (self.cache_status == CacheStatus.CPU) and (len(self.children) if (self.cache_status == CacheStatus.CPU) and (len(self.children) == 0):
== 0):
return True return True
return False return False


@@ -21,20 +21,20 @@ import time
import numpy as np import numpy as np
import paddle import paddle
from fastdeploy.cache_manager.transfer_factory import (IPCCommManager, from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
RDMACommManager)
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log") logger = get_logger("cache_messager", "cache_messager.log")
class CacheMessager(object): class CacheMessager:
""" """
CacheMessager is used to send the cache data between the engine worker and the cache server. CacheMessager is used to send the cache data between the engine worker and the cache server.
""" """
def __init__(self, def __init__(
self,
splitwise_role, splitwise_role,
transfer_protocol, transfer_protocol,
pod_ip, pod_ip,
@@ -45,7 +45,8 @@ class CacheMessager(object):
nranks, nranks,
num_layers, num_layers,
gpu_id=0, gpu_id=0,
rdma_port=None): rdma_port=None,
):
""" """
Initialize the CacheMessager object. Initialize the CacheMessager object.
@@ -64,8 +65,10 @@ class CacheMessager(object):
None None
""" """
assert splitwise_role in ["prefill", "decode"], \ assert splitwise_role in [
"splitwise_role must be prefill or decode" "prefill",
"decode",
], "splitwise_role must be prefill or decode"
self.splitwise_role = splitwise_role self.splitwise_role = splitwise_role
self.gpu_cache_kvs = gpu_cache_kvs self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank self.rank = rank
@@ -76,11 +79,11 @@ class CacheMessager(object):
is_server=False, is_server=False,
num_client=self.nranks, num_client=self.nranks,
client_id=self.rank, client_id=self.rank,
local_data_parallel_id=local_data_parallel_id) local_data_parallel_id=local_data_parallel_id,
)
transfer_protocol = transfer_protocol.split(",") transfer_protocol = transfer_protocol.split(",")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
f"rank: {rank}")
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers self.num_layers = num_layers
@@ -90,10 +93,8 @@ class CacheMessager(object):
cache_v = [] cache_v = []
self.messager = {} self.messager = {}
for layer_idx in range(self.num_layers): for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[ key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
f'key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}'] val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[
f'value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
cache_k.append(key_cache) cache_k.append(key_cache)
cache_v.append(val_cache) cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr()) cache_k_ptr_list.append(key_cache.data_ptr())
@@ -109,7 +110,8 @@ class CacheMessager(object):
block_bytes *= 2 block_bytes *= 2
logger.info( logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}") f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
)
self.block_bytes = block_bytes self.block_bytes = block_bytes
# 3. initialize the messager # 3. initialize the messager
@@ -122,24 +124,26 @@ class CacheMessager(object):
cache_v, cache_v,
) )
local_device_id = int(str(cache_k[0].place)[-2]) local_device_id = int(str(cache_k[0].place)[-2])
logger.info( logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
f"done create ipc_comm with local_device_id:{local_device_id}, "
)
elif protocol == "rdma": elif protocol == "rdma":
logger.info( logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}"
)
self.messager[protocol] = RDMACommManager( self.messager[protocol] = RDMACommManager(
splitwise_role, rank, gpu_id, cache_k_ptr_list, splitwise_role,
cache_v_ptr_list, max_block_num, block_bytes, rdma_port) rank,
gpu_id,
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
rdma_port,
)
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.cache_info = dict() self.cache_info = dict()
layerwise_send_cache_thread = threading.Thread( layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread.daemon = True layerwise_send_cache_thread.daemon = True
layerwise_send_cache_thread.start() layerwise_send_cache_thread.start()
@@ -159,26 +163,30 @@ class CacheMessager(object):
array=prefilled_step_idx_data, array=prefilled_step_idx_data,
dtype=np.int32, dtype=np.int32,
suffix=self.gpu_id, suffix=self.gpu_id,
create=True) create=True,
)
layer_shm_value = IPCSignal( layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.rank}", name=f"splitwise_complete_prefilled_layer_{self.rank}",
array=prefilled_layer_idx_data, array=prefilled_layer_idx_data,
dtype=np.int32, dtype=np.int32,
suffix=self.gpu_id, suffix=self.gpu_id,
create=True) create=True,
)
except: except:
step_shm_value = IPCSignal( step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.rank}", name=f"splitwise_complete_prefilled_step_{self.rank}",
array=prefilled_step_idx_data, array=prefilled_step_idx_data,
dtype=np.int32, dtype=np.int32,
suffix=self.gpu_id, suffix=self.gpu_id,
create=False) create=False,
)
layer_shm_value = IPCSignal( layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.rank}", name=f"splitwise_complete_prefilled_layer_{self.rank}",
array=prefilled_layer_idx_data, array=prefilled_layer_idx_data,
dtype=np.int32, dtype=np.int32,
suffix=self.gpu_id, suffix=self.gpu_id,
create=False) create=False,
)
step_shm_value.value[0] = -1 step_shm_value.value[0] = -1
layer_shm_value.value[0] = -1 layer_shm_value.value[0] = -1
@@ -193,21 +201,19 @@ class CacheMessager(object):
if cache_info: if cache_info:
logger.debug(f"cache info {cache_info}") logger.debug(f"cache info {cache_info}")
for info in cache_info: for info in cache_info:
if info['request_id'] in self.cache_info: if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info) self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]] current_info = self.cache_info[info["request_id"]]
if "dest_block_ids" in current_info and "src_block_ids" in current_info: if "dest_block_ids" in current_info and "src_block_ids" in current_info:
current_src_blocks = current_info[ current_src_blocks = current_info["src_block_ids"][
"src_block_ids"][-len(current_info["dest_block_ids"]):] -len(current_info["dest_block_ids"]) :
current_info[ ]
"src_block_ids"] = current_src_blocks current_info["src_block_ids"] = current_src_blocks
current_info["current_layer_ids"] = 0 current_info["current_layer_ids"] = 0
current_info["status"] = "init" current_info["status"] = "init"
logger.info( logger.info(f"start cache_infos: {current_info}")
f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min( self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
self.last_step_idx, current_info['current_id'])
else: else:
self.cache_info[info["request_id"]] = info self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0] prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,64 +229,53 @@ class CacheMessager(object):
if not self.cache_info: if not self.cache_info:
time.sleep(0.001) time.sleep(0.001)
continue continue
logger.debug( logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}"
)
for req_id, item in list(self.cache_info.items()): for req_id, item in list(self.cache_info.items()):
if "status" not in item: if "status" not in item:
continue continue
if "layer_idx" not in item: if "layer_idx" not in item:
item["layer_idx"] = 0 item["layer_idx"] = 0
if item['status'] == 'error': if item["status"] == "error":
del self.cache_info[req_id] del self.cache_info[req_id]
continue continue
if item['current_id'] > prefilled_step_idx: if item["current_id"] > prefilled_step_idx:
continue continue
current_transfer_protocol = item["transfer_protocol"] current_transfer_protocol = item["transfer_protocol"]
if item["transfer_protocol"] == "rdma": if item["transfer_protocol"] == "rdma":
target_ip = item['ip'] target_ip = item["ip"]
target_id = int(item['rdma_ports'][self.rank]) target_id = int(item["rdma_ports"][self.rank])
status = self.messager[ status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
current_transfer_protocol].connect(
target_ip, target_id)
if not status: if not status:
logger.error( logger.error(f"connect to {target_ip}:{target_id} failed")
f"connect to {target_ip}:{target_id} failed")
item["status"] = "error" item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait() self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0: if self.rank == 0:
self.engine_worker_queue.put_finished_req([ self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
(item['request_id'], "connect error")
])
continue continue
elif item["transfer_protocol"] == "ipc": elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0" target_ip = "0.0.0.0"
target_id = int(item['device_ids'][self.rank]) target_id = int(item["device_ids"][self.rank])
src_block_ids = paddle.to_tensor(item['src_block_ids'], src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
dtype='int32', dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
place='cpu') if item["current_id"] < prefilled_step_idx:
dest_block_ids = paddle.to_tensor(item['dest_block_ids'],
dtype='int32',
place='cpu')
if item['current_id'] < prefilled_step_idx:
current_layer_idx = self.num_layers current_layer_idx = self.num_layers
else: else:
current_layer_idx = prefilled_layer_idx + 1 current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"], for layer_idx in range(item["layer_idx"], current_layer_idx):
current_layer_idx):
tic = time.time() tic = time.time()
return_code = self.messager[ return_code = self.messager[current_transfer_protocol].write_cache(
current_transfer_protocol].write_cache( target_ip,
target_ip, target_id, src_block_ids, target_id,
dest_block_ids, layer_idx) src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0: if return_code != 0:
item["status"] = "error" item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait() self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0: if self.rank == 0:
self.engine_worker_queue.put_finished_req([ self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
(item['request_id'], "write cache error")
])
logger.error( logger.error(
f"write cache failed, layer_idx: {layer_idx}, " f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}" f"req_id: {item['request_id']}, dest_ip: {target_ip}"
@@ -298,16 +293,14 @@ class CacheMessager(object):
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)}," f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}" f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
) )
item['layer_idx'] = current_layer_idx item["layer_idx"] = current_layer_idx
if item['layer_idx'] == self.num_layers: if item["layer_idx"] == self.num_layers:
if item["transfer_protocol"] == "ipc": if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id) self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}") logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait() self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0: if self.rank == 0:
self.engine_worker_queue.put_finished_req([ self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
(item['request_id'], "finished")
])
logger.info(f"put write cache {item['request_id']}") logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id] del self.cache_info[req_id]
@@ -315,5 +308,4 @@ class CacheMessager(object):
self.last_layer_idx = prefilled_layer_idx self.last_layer_idx = prefilled_layer_idx
except Exception as e: except Exception as e:
logger.error( logger.error(f"prefill layerwise send cache thread has exception: {e}")
f"prefill layerwise send cache thread has exception: {e}")

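The `IPCSignal` constructions that were spread onto one keyword per line above all follow the same pattern; a minimal sketch of that pattern as used in these hunks, with an illustrative name and suffix (the creating side passes `create=True`, readers pass `create=False`):

import numpy as np

from fastdeploy.inter_communicator import IPCSignal

prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
step_shm_value = IPCSignal(
    name="splitwise_complete_prefilled_step_0",  # rank-suffixed name, as above
    array=prefilled_step_idx_data,
    dtype=np.int32,
    suffix=0,  # the hunks above pass gpu_id or engine_pid here; 0 is illustrative
    create=True,
)
step_shm_value.value[0] = -1  # the shared value is read and written through .value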

@@ -14,18 +14,16 @@
# limitations under the License. # limitations under the License.
""" """
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
class CacheMetrics: class CacheMetrics:
""" """
Cache Metrics used to record the cache hit time, token num, request num, etc. Cache Metrics used to record the cache hit time, token num, request num, etc.
""" """
def __init__(self): def __init__(self):
self.total_match_time = 0.0 self.total_match_time = 0.0
self.avg_match_time = 0.0 self.avg_match_time = 0.0
@@ -47,19 +45,14 @@ class CacheMetrics:
self.cpu_hit_token_ratio = 0.0 self.cpu_hit_token_ratio = 0.0
self.gpu_hit_token_ratio = 0.0 self.gpu_hit_token_ratio = 0.0
def _update_history_hit_metrics(self): def _update_history_hit_metrics(self):
""" """
update hit ratio update hit ratio
""" """
self.hit_req_ratio = self.hit_req_count / self.req_count self.hit_req_ratio = self.hit_req_count / self.req_count
self.hit_token_ratio = self.matched_token_num / self.total_token_num self.hit_token_ratio = self.matched_token_num / self.total_token_num
self.cpu_hit_token_ratio = ( self.cpu_hit_token_ratio = self.total_cpu_matched_token_num / self.total_token_num
self.total_cpu_matched_token_num / self.total_token_num self.gpu_hit_token_ratio = self.total_gpu_matched_token_num / self.total_token_num
)
self.gpu_hit_token_ratio = (
self.total_gpu_matched_token_num / self.total_token_num
)
logger.info( logger.info(
f"Metrics for all requests: req_count {self.req_count} hit_req_count {self.hit_req_count}" f"Metrics for all requests: req_count {self.req_count} hit_req_count {self.hit_req_count}"
@@ -83,29 +76,15 @@ class CacheMetrics:
calculate hit metrics for current query calculate hit metrics for current query
""" """
cpu_cache_match_ratio = ( cpu_cache_match_ratio = current_query_cpu_match_token_num / current_query_token_num
current_query_cpu_match_token_num / current_query_token_num gpu_cache_match_ratio = current_query_gpu_match_token_num / current_query_token_num
)
gpu_cache_match_ratio = (
current_query_gpu_match_token_num / current_query_token_num
)
total_match_ratio = ( total_match_ratio = cpu_cache_match_ratio + gpu_cache_match_ratio
cpu_cache_match_ratio + gpu_cache_match_ratio
)
self.total_cpu_matched_token_num += current_query_cpu_match_token_num
self.total_gpu_matched_token_num += current_query_gpu_match_token_num
self.total_cpu_matched_token_num += ( self.matched_token_num += current_query_cpu_match_token_num + current_query_gpu_match_token_num
current_query_cpu_match_token_num
)
self.total_gpu_matched_token_num += (
current_query_gpu_match_token_num
)
self.matched_token_num += (
current_query_cpu_match_token_num
+ current_query_gpu_match_token_num
)
self.total_token_num += current_query_token_num self.total_token_num += current_query_token_num
logger.info( logger.info(
f"Metrics for req_id {req_id}: token_num {current_query_token_num}" f"Metrics for req_id {req_id}: token_num {current_query_token_num}"

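The simplified ratio updates above keep the same arithmetic as before the reformat; a small worked example with made-up counters (numbers are illustrative only):

# illustrative counters, mirroring the fields used by CacheMetrics above
req_count, hit_req_count = 10, 7
total_token_num, matched_token_num = 2000, 800
total_cpu_matched_token_num, total_gpu_matched_token_num = 300, 500

hit_req_ratio = hit_req_count / req_count                            # 0.7
hit_token_ratio = matched_token_num / total_token_num                # 0.4
cpu_hit_token_ratio = total_cpu_matched_token_num / total_token_num  # 0.15
gpu_hit_token_ratio = total_gpu_matched_token_num / total_token_num  # 0.25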

@@ -26,8 +26,11 @@ import paddle
from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.engine.config import SpeculativeConfig from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (cuda_host_alloc, set_data_ipc, from fastdeploy.model_executor.ops.gpu import (
swap_cache_all_layers) cuda_host_alloc,
set_data_ipc,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
@@ -36,79 +39,58 @@ def parse_args():
Parse arguments from the command line Parse arguments from the command line
""" """
parser = argparse.ArgumentParser("Cache transfer manager") parser = argparse.ArgumentParser("Cache transfer manager")
parser.add_argument("--splitwise_role", parser.add_argument(
"--splitwise_role",
type=str, type=str,
default="mixed", default="mixed",
help="splitwise role, can be decode, prefill or mixed") help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank") parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
type=int, parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
default=1, parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
help="model num layers")
parser.add_argument("--head_dim",
type=int,
default=1,
help="model head dim")
parser.add_argument("--kv_num_head",
type=int,
default=1,
help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port") parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
type=int, parser.add_argument(
default=1, "--protocol",
help="number of model parallel")
parser.add_argument("--protocol",
type=str, type=str,
default="ipc", default="ipc",
help="cache transfer protocol, only surport ipc now") help="cache transfer protocol, only surport ipc now",
parser.add_argument("--enable_splitwise", )
type=int, parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
default=0, parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
help="enable splitwise ") parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument("--cache_queue_port", parser.add_argument(
"--engine_worker_queue_port",
type=int, type=int,
default=9923, default=9923,
help="cache queue port") help="engine worker queue port",
parser.add_argument("--pod_ip", )
type=str, parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
default="0.0.0.0",
help="pod ip")
parser.add_argument("--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port")
parser.add_argument("--engine_pid",
type=str,
default=None,
help="engine pid")
parser.add_argument("--num_gpu_blocks", parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
type=int, parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
default=1, parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
help="gpu cache block number") parser.add_argument(
parser.add_argument("--num_cpu_blocks", "--bytes_per_layer_per_block",
type=int,
default=4,
help="cpu cache block number")
parser.add_argument("--block_size",
type=int,
default=64,
help="cache block size(tokens)")
parser.add_argument("--bytes_per_layer_per_block",
type=int, type=int,
default=1024, default=1024,
help="per layer per block bytes") help="per layer per block bytes",
parser.add_argument("--cache_dtype", )
parser.add_argument(
"--cache_dtype",
type=str, type=str,
default="bfloat16", default="bfloat16",
choices=["uint8", "bfloat16"], choices=["uint8", "bfloat16"],
help="cache dtype") help="cache dtype",
parser.add_argument("--speculative_config", )
parser.add_argument(
"--speculative_config",
type=json.loads, type=json.loads,
default="{}", default="{}",
help="speculative config") help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0) parser.add_argument("--local_data_parallel_id", type=int, default=0)
args = parser.parse_args() args = parser.parse_args()
@@ -134,14 +116,10 @@ class CacheTransferManager:
self.gpu_cache_v_tensors = [] self.gpu_cache_v_tensors = []
self.speculative_config = SpeculativeConfig(**args.speculative_config) self.speculative_config = SpeculativeConfig(**args.speculative_config)
self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = \ self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
int(args.num_gpu_blocks * \
self.speculative_config.num_gpu_block_expand_ratio)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor( self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
max_workers=1) self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=1)
self.transfer_task_queue = queue.Queue()  # used to receive transfer tasks self.transfer_task_queue = queue.Queue()  # used to receive transfer tasks
self.tansfer_done_queue = queue.Queue()  # used to signal that a task has finished self.tansfer_done_queue = queue.Queue()  # used to signal that a task has finished
self.n_ranks = args.mp_num self.n_ranks = args.mp_num
@@ -154,17 +132,16 @@ class CacheTransferManager:
is_server=False, is_server=False,
num_client=args.mp_num, num_client=args.mp_num,
client_id=rank, client_id=rank,
local_data_parallel_id=args.local_data_parallel_id) local_data_parallel_id=args.local_data_parallel_id,
)
self.num_cpu_blocks = args.num_cpu_blocks self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers): for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else \ num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
self.num_extra_layer_gpu_blocks
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format( self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
i, rank, device)] = paddle.full(
shape=[ shape=[
num_gpu_blocks, num_gpu_blocks,
args.kv_num_head, args.kv_num_head,
@@ -174,11 +151,8 @@ class CacheTransferManager:
fill_value=0, fill_value=0,
dtype=cache_type, dtype=cache_type,
) )
self.gpu_cache_k_tensors.append( self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format( self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
i, rank, device)])
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)] = paddle.full(
shape=[ shape=[
num_gpu_blocks, num_gpu_blocks,
args.kv_num_head, args.kv_num_head,
@@ -188,47 +162,42 @@ class CacheTransferManager:
fill_value=0, fill_value=0,
dtype=cache_type, dtype=cache_type,
) )
self.gpu_cache_v_tensors.append( self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)])
set_data_ipc( set_data_ipc(
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format( self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
i, rank, device)], f"key_caches_{i}_rank{rank}.device{device}",
"key_caches_{}_rank{}.device{}".format(i, rank, device)) )
set_data_ipc( set_data_ipc(
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format( self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
i, rank, device)], f"value_caches_{i}_rank{rank}.device{device}",
"value_caches_{}_rank{}.device{}".format(i, rank, device)) )
cache_kv_size_byte = sum( cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
[tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}") logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info( logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
paddle.set_device("cpu") paddle.set_device("cpu")
self.k_dst_ptrs = [] self.k_dst_ptrs = []
self.v_dst_ptrs = [] self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers): for i in range(args.num_layers + self.num_extra_layers):
self.cpu_cache_kvs["key_caches_{}_rank{}".format( self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
i, rank)] = cuda_host_alloc(args.num_cpu_blocks * args.num_cpu_blocks * args.bytes_per_layer_per_block
args.bytes_per_layer_per_block) )
self.k_dst_ptrs.append( self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
self.cpu_cache_kvs["key_caches_{}_rank{}".format(i, rank)]) self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
self.cpu_cache_kvs["value_caches_{}_rank{}".format( args.num_cpu_blocks * args.bytes_per_layer_per_block
i, rank)] = cuda_host_alloc(args.num_cpu_blocks * )
args.bytes_per_layer_per_block) self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
self.v_dst_ptrs.append(
self.cpu_cache_kvs["value_caches_{}_rank{}".format(i, rank)])
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(name="cache_ready_signal", self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data, array=cache_ready_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=args.engine_pid, suffix=args.engine_pid,
create=False) create=False,
)
self.cache_ready_signal.value[self.rank] = 1 self.cache_ready_signal.value[self.rank] = 1
paddle.set_device(f"gpu:{device}") paddle.set_device(f"gpu:{device}")
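The cache key renaming in the hunks above is only a switch from `str.format` to f-strings; the generated names are identical. A quick check with illustrative indices:

i, rank, device = 0, 1, 2
old_key = "key_caches_{}_rank{}_device{}".format(i, rank, device)
new_key = f"key_caches_{i}_rank{rank}_device{device}"
assert old_key == new_key == "key_caches_0_rank1_device2"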
@@ -251,9 +220,7 @@ class CacheTransferManager:
rdma_port=args.rdma_port, rdma_port=args.rdma_port,
) )
logger.info("successfully create cache messager") logger.info("successfully create cache messager")
logger.info( logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal( self.cache_task_broadcast_signal = IPCSignal(
@@ -261,10 +228,17 @@ class CacheTransferManager:
array=cache_task_broadcast_data, array=cache_task_broadcast_data,
dtype=np.int32, dtype=np.int32,
suffix=args.engine_pid, suffix=args.engine_pid,
create=False) create=False,
)
def _do_swap_to_cpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id, def _do_swap_to_cpu_task(
event_type, transfer_task_id): self,
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
):
""" """
swap cache GPU->CPU swap cache GPU->CPU
""" """
@@ -282,14 +256,17 @@ class CacheTransferManager:
if self.rank == 0: if self.rank == 0:
self.cache_task_queue.swap_to_cpu_barrier2.reset() self.cache_task_queue.swap_to_cpu_barrier2.reset()
self.cache_task_queue.put_transfer_done_signal(result) self.cache_task_queue.put_transfer_done_signal(result)
logger.debug( logger.debug(f"_do_swap_to_cpu_task: put_transfer_done_signal {result}")
f"_do_swap_to_cpu_task: put_transfer_done_signal {result}") logger.info(f"_do_swap_to_cpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")
logger.info(
f"_do_swap_to_cpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
)
def _do_swap_to_gpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id, def _do_swap_to_gpu_task(
event_type, transfer_task_id): self,
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
):
""" """
swap cache CPU->GPU swap cache CPU->GPU
""" """
@@ -307,11 +284,8 @@ class CacheTransferManager:
if self.rank == 0: if self.rank == 0:
self.cache_task_queue.swap_to_gpu_barrier2.reset() self.cache_task_queue.swap_to_gpu_barrier2.reset()
self.cache_task_queue.put_transfer_done_signal(result) self.cache_task_queue.put_transfer_done_signal(result)
logger.debug( logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
f"_do_swap_to_gpu_task: put_transfer_done_signal {result}") logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")
logger.info(
f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
)
def do_data_transfer(self): def do_data_transfer(self):
""" """
@@ -327,8 +301,7 @@ class CacheTransferManager:
if self.rank == 0: if self.rank == 0:
self.cache_task_queue.barrier1.reset() self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1: if self.cache_task_broadcast_signal.value[0] == 1:
data, read_finish = self.cache_task_queue.get_transfer_task( data, read_finish = self.cache_task_queue.get_transfer_task()
)
logger.debug(f"transfer data: get_transfer_task {data}") logger.debug(f"transfer data: get_transfer_task {data}")
if read_finish: if read_finish:
self.cache_task_broadcast_signal.value[0] = 0 self.cache_task_broadcast_signal.value[0] = 0
@@ -386,8 +359,7 @@ class CacheTransferManager:
""" """
logger.debug( logger.debug(
f"transfer data: transfer_task_id {transfer_task_id}: swap_node_ids {swap_node_ids}" f"transfer data: transfer_task_id {transfer_task_id}: swap_node_ids {swap_node_ids}"
+ + f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
) )
start_time = time.time() start_time = time.time()
try: try:
@@ -446,8 +418,7 @@ class CacheTransferManager:
elasped_time = end_time - start_time elasped_time = end_time - start_time
logger.info( logger.info(
f"transfer data: transfer_task_id {transfer_task_id} event_type {event_type}: " f"transfer data: transfer_task_id {transfer_task_id} event_type {event_type}: "
+ + f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
) )
return ( return (
swap_node_ids, swap_node_ids,


@@ -41,11 +41,13 @@ class PrefixCacheManager:
PrefixCacheManager is used to manage the prefix tree and the cache. PrefixCacheManager is used to manage the prefix tree and the cache.
""" """
def __init__(self, def __init__(
self,
config, config,
tensor_parallel_size, tensor_parallel_size,
splitwise_role="mixed", splitwise_role="mixed",
local_data_parallel_id=0): local_data_parallel_id=0,
):
""" """
initialize the PrefixCacheManager initialize the PrefixCacheManager
""" """
@@ -66,14 +68,12 @@ class PrefixCacheManager:
self.num_cpu_blocks = self.cache_config.num_cpu_blocks self.num_cpu_blocks = self.cache_config.num_cpu_blocks
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1)) self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0: if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list( self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
range(self.num_cpu_blocks - 1, -1, -1))
else: else:
self.cpu_free_block_list = [] self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list) heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list) heapq.heapify(self.cpu_free_block_list)
self.node_id_pool = list( self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None) self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
@@ -90,7 +90,7 @@ class PrefixCacheManager:
self.task_swapping_event = {} self.task_swapping_event = {}
self.node_map = {} self.node_map = {}
self.req_leaf_map = ({}) # {request_id: leaf node} self.req_leaf_map = {} # {request_id: leaf node}
self.leaf_req_map = defaultdict(set) self.leaf_req_map = defaultdict(set)
self.unfilled_req_block_map = defaultdict(list) self.unfilled_req_block_map = defaultdict(list)
@@ -102,14 +102,18 @@ class PrefixCacheManager:
logger.info( logger.info(
f"num_gpu_blocks_server_owned {self.num_gpu_blocks} num_cpu_blocks " f"num_gpu_blocks_server_owned {self.num_gpu_blocks} num_cpu_blocks "
+ + f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
) )
def launch_cache_manager(
self,
def launch_cache_manager(self, cache_config, tensor_parallel_size, \ cache_config,
device_ids, pod_ip, engine_worker_queue_port, pid_suffix): tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
):
""" """
launch_cache_manager function used to initialize the cache manager. launch_cache_manager function used to initialize the cache manager.
""" """
@@ -120,70 +124,72 @@ class PrefixCacheManager:
array=broadcast_cache_task_flag_array, array=broadcast_cache_task_flag_array,
dtype=np.int32, dtype=np.int32,
suffix=pid_suffix, suffix=pid_suffix,
create=True) create=True,
)
self.cache_task_queue = EngineCacheQueue( self.cache_task_queue = EngineCacheQueue(
address=(pod_ip, cache_config.cache_queue_port), address=(pod_ip, cache_config.cache_queue_port),
authkey=b'cache_queue_service', authkey=b"cache_queue_service",
is_server=False, is_server=False,
num_client=tensor_parallel_size, num_client=tensor_parallel_size,
client_id=0, client_id=0,
local_data_parallel_id=self.local_data_parallel_id) local_data_parallel_id=self.local_data_parallel_id,
)
current_dir_path = os.path.split(os.path.abspath(__file__))[0] current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_transfer_manager.py" filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename) py_path = os.path.join(current_dir_path, filename)
if (hasattr(cache_config.model_cfg, "num_key_value_heads") if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads") and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0): and int(cache_config.model_cfg.num_key_value_heads) > 0
kv_num_head = int(cache_config.model_cfg.num_key_value_heads ):
) // tensor_parallel_size kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else: else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
dtype=np.int32) self.cache_ready_signal = IPCSignal(
self.cache_ready_signal = IPCSignal(name="cache_ready_signal", name="cache_ready_signal",
array=cache_ready_signal_data, array=cache_ready_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=pid_suffix, suffix=pid_suffix,
create=True) create=True,
)
log_dir = envs.FD_LOG_DIR log_dir = envs.FD_LOG_DIR
cache_manager_processes = [] cache_manager_processes = []
for i in range(tensor_parallel_size): for i in range(tensor_parallel_size):
launch_cmd = ( launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7" "FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
f" {sys.executable} {py_path}" + + f" {sys.executable} {py_path}"
f" --device_id {int(device_ids[i])}" + f" --rank {i}" + + f" --device_id {int(device_ids[i])}"
f" --splitwise_role {self.splitwise_role}" + + f" --rank {i}"
f" --num_layers {cache_config.model_cfg.num_layers}" + + f" --splitwise_role {self.splitwise_role}"
f" --head_dim {cache_config.model_cfg.head_dim}" + + f" --num_layers {cache_config.model_cfg.num_layers}"
f" --kv_num_head {kv_num_head}" + + f" --head_dim {cache_config.model_cfg.head_dim}"
f" --mp_num {tensor_parallel_size}" + + f" --kv_num_head {kv_num_head}"
f" --cache_dtype {cache_config.cache_dtype}" + + f" --mp_num {tensor_parallel_size}"
f" --cache_queue_port {cache_config.cache_queue_port}" + + f" --cache_dtype {cache_config.cache_dtype}"
f" --enable_splitwise {int(self.enable_splitwise)}" + + f" --cache_queue_port {cache_config.cache_queue_port}"
f" --pod_ip {pod_ip}" + + f" --enable_splitwise {int(self.enable_splitwise)}"
f" --engine_worker_queue_port {engine_worker_queue_port}" + + f" --pod_ip {pod_ip}"
f" --num_gpu_blocks {cache_config.total_block_num}" + + f" --engine_worker_queue_port {engine_worker_queue_port}"
f" --num_cpu_blocks {cache_config.num_cpu_blocks}" + + f" --num_gpu_blocks {cache_config.total_block_num}"
f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}" + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --block_size {cache_config.block_size}" + + f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
f" --engine_pid {pid_suffix}" + + f" --block_size {cache_config.block_size}"
f" --protocol {cache_config.cache_transfer_protocol}" + + f" --engine_pid {pid_suffix}"
f" --local_data_parallel_id {self.local_data_parallel_id}" + + f" --protocol {cache_config.cache_transfer_protocol}"
f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --local_data_parallel_id {self.local_data_parallel_id}"
+ + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
) )
logger.info(f"Launch cache transfer manager, command:{launch_cmd}") logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append( cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# wait for the cache to finish initializing # wait for the cache to finish initializing
logger.info("Waiting for cache transfer manager ready...") logger.info("Waiting for cache transfer manager ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size: while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
@@ -192,9 +198,7 @@ class PrefixCacheManager:
if exit_code is None: if exit_code is None:
logger.info("Launch cache transfer manager successful") logger.info("Launch cache transfer manager successful")
else: else:
logger.info( logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
"Launch cache transfer manager failed, see launch_cache_manager.log for more information"
)
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0: if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.") logger.info("Enable hierarchical cache.")
@@ -207,12 +211,10 @@ class PrefixCacheManager:
""" """
self.cache_config = cache_config self.cache_config = cache_config
self.num_gpu_blocks = cache_config.prefill_kvcache_block_num self.num_gpu_blocks = cache_config.prefill_kvcache_block_num
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))  # remaining block ids on the server-managed GPU
-1))  # remaining block ids on the server-managed GPU
heapq.heapify(self.gpu_free_block_list) heapq.heapify(self.gpu_free_block_list)
self.node_id_pool = list( self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
range(self.num_gpu_blocks + self.num_cpu_blocks))
def _enable_cpu_cache(self): def _enable_cpu_cache(self):
""" """
@@ -226,8 +228,7 @@ class PrefixCacheManager:
# port=ipc_cache_queue_port, # port=ipc_cache_queue_port,
# ) # )
# start the listener thread that receives transfer-task results # start the listener thread that receives transfer-task results
self.transfer_recv_thread = threading.Thread( self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
target=self.recv_data_transfer_result)
self.transfer_recv_thread.start() self.transfer_recv_thread.start()
def allocate_gpu_blocks(self, num_blocks): def allocate_gpu_blocks(self, num_blocks):
@@ -237,9 +238,7 @@ class PrefixCacheManager:
assert num_blocks <= len( assert num_blocks <= len(
self.gpu_free_block_list self.gpu_free_block_list
), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}" ), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}"
allocated_block_ids = [ allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)]
heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)
]
logger.info( logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
) )
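Since `allocate_gpu_blocks` now pops from the heapified free list in a single comprehension, the allocation order is worth noting; a small sketch with an illustrative pool of 4 blocks:

import heapq

free_blocks = list(range(4 - 1, -1, -1))  # [3, 2, 1, 0], built the same way as in __init__ above
heapq.heapify(free_blocks)                # min-heap: smallest free block id on top
allocated = [heapq.heappop(free_blocks) for _ in range(2)]
print(allocated)                          # [0, 1], lowest-numbered blocks are handed out first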
@@ -265,9 +264,7 @@ class PrefixCacheManager:
assert num_blocks <= len( assert num_blocks <= len(
self.cpu_free_block_list self.cpu_free_block_list
), f"cpu free block num: {len(self.cpu_free_block_list)} < needed number {num_blocks}" ), f"cpu free block num: {len(self.cpu_free_block_list)} < needed number {num_blocks}"
allocated_block_ids = [ allocated_block_ids = [heapq.heappop(self.cpu_free_block_list) for i in range(num_blocks)]
heapq.heappop(self.cpu_free_block_list) for i in range(num_blocks)
]
logger.info( logger.info(
f"allocate_cpu_blocks: {allocated_block_ids}, len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}" f"allocate_cpu_blocks: {allocated_block_ids}, len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
) )
@@ -307,16 +304,17 @@ class PrefixCacheManager:
""" """
self.task_swapping_event[transfer_task_id] = Event() self.task_swapping_event[transfer_task_id] = Event()
self.cache_task_queue.put_transfer_task(( self.cache_task_queue.put_transfer_task(
(
swap_node_ids, swap_node_ids,
gpu_block_ids, gpu_block_ids,
cpu_block_ids, cpu_block_ids,
event_type, event_type,
transfer_task_id, transfer_task_id,
)) )
)
if is_sync: if is_sync:
self.sync_swap_task(transfer_task_id) self.sync_swap_task(transfer_task_id)
return
def sync_swap_task(self, transfer_task_id): def sync_swap_task(self, transfer_task_id):
""" """
@@ -325,26 +323,27 @@ class PrefixCacheManager:
self.task_swapping_event[transfer_task_id].wait() self.task_swapping_event[transfer_task_id].wait()
del self.task_swapping_event[transfer_task_id] del self.task_swapping_event[transfer_task_id]
def _check_validity(self, req_id, match_gpu_blocks_num, def _check_validity(self, req_id, match_gpu_blocks_num, expected_block_num):
expected_block_num):
""" """
check enough gpu memory to allocate cache check enough gpu memory to allocate cache
""" """
if expected_block_num - match_gpu_blocks_num > len( if expected_block_num - match_gpu_blocks_num > len(self.gpu_free_block_list):
self.gpu_free_block_list):
msg = ( msg = (
f"request_block_ids: request block for req_id {req_id} failed. " f"request_block_ids: request block for req_id {req_id} failed. "
+ + f"matched gpu block num: {match_gpu_blocks_num} require extra gpu block num: "
f"matched gpu block num: {match_gpu_blocks_num} require extra gpu block num: " + f"{expected_block_num - match_gpu_blocks_num} > free block num: {len(self.gpu_free_block_list)}"
+
f"{expected_block_num - match_gpu_blocks_num} > free block num: {len(self.gpu_free_block_list)}"
) )
logger.info(msg) logger.info(msg)
raise Exception("Not enough GPU memory to allocate cache") raise Exception("Not enough GPU memory to allocate cache")
def _prepare_cpu_cache(
def _prepare_cpu_cache(self, req_id, swap_node_ids, gpu_recv_block_ids, \ self,
cpu_recv_block_ids, match_cpu_block_ids): req_id,
swap_node_ids,
gpu_recv_block_ids,
cpu_recv_block_ids,
match_cpu_block_ids,
):
""" """
Transfer the CPU cache to the GPU Transfer the CPU cache to the GPU
""" """
@@ -357,11 +356,8 @@ class PrefixCacheManager:
for tmp_cpu_block_id in match_cpu_block_ids: for tmp_cpu_block_id in match_cpu_block_ids:
need_transfer_task_cpu_block_ids.append(tmp_cpu_block_id) need_transfer_task_cpu_block_ids.append(tmp_cpu_block_id)
assert len(need_transfer_task_gpu_block_ids) == len( assert len(need_transfer_task_gpu_block_ids) == len(need_transfer_task_cpu_block_ids)
need_transfer_task_cpu_block_ids) logger.info(f"request_block_ids: req_id {req_id} issue_swap_task transfer_task_id {transfer_task_id}")
logger.info(
f"request_block_ids: req_id {req_id} issue_swap_task transfer_task_id {transfer_task_id}"
)
self.issue_swap_task( self.issue_swap_task(
transfer_task_id, transfer_task_id,
swap_node_ids, swap_node_ids,
@@ -371,8 +367,16 @@ class PrefixCacheManager:
True, True,
) )
def _prepare_cache(self, req_id, input_ids, block_size, \ def _prepare_cache(
expected_block_num, match_gpu_block_ids, match_cpu_block_ids, match_node_ids): self,
req_id,
input_ids,
block_size,
expected_block_num,
match_gpu_block_ids,
match_cpu_block_ids,
match_node_ids,
):
""" """
prepare cache for request prepare cache for request
""" """
@@ -394,8 +398,13 @@ class PrefixCacheManager:
gpu_extra_block_ids = self.allocate_gpu_blocks(gpu_extra_block_num) gpu_extra_block_ids = self.allocate_gpu_blocks(gpu_extra_block_num)
if len(gpu_recv_block_ids) > 0: if len(gpu_recv_block_ids) > 0:
self._prepare_cpu_cache(req_id, match_node_ids, gpu_recv_block_ids, \ self._prepare_cpu_cache(
cpu_recv_block_ids, match_cpu_block_ids) req_id,
match_node_ids,
gpu_recv_block_ids,
cpu_recv_block_ids,
match_cpu_block_ids,
)
return gpu_recv_block_ids, gpu_extra_block_ids return gpu_recv_block_ids, gpu_extra_block_ids
@@ -423,9 +432,7 @@ class PrefixCacheManager:
self.metrics.req_count += 1 self.metrics.req_count += 1
input_ids = task.prompt_token_ids input_ids = task.prompt_token_ids
req_id = task.request_id req_id = task.request_id
logger.info( logger.info(f"request_block_ids: start to allocate blocks for req_id {req_id}")
f"request_block_ids: start to allocate blocks for req_id {req_id}"
)
input_token_num = len(input_ids) input_token_num = len(input_ids)
common_block_ids = [] common_block_ids = []
unique_block_ids = [] unique_block_ids = []
@@ -443,34 +450,43 @@ class PrefixCacheManager:
matched_block_num = match_gpu_blocks_num + match_cpu_blocks_num matched_block_num = match_gpu_blocks_num + match_cpu_blocks_num
matched_token_num_in_cpu_and_gpu = gpu_match_token_num + cpu_match_token_num matched_token_num_in_cpu_and_gpu = gpu_match_token_num + cpu_match_token_num
# check enough gpu memory to allocate cache # check enough gpu memory to allocate cache
block_num = (input_token_num + block_size - 1 + block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
dec_token_num) // block_size
self._check_validity(req_id, matched_block_num, block_num) self._check_validity(req_id, matched_block_num, block_num)
# update matched node info # update matched node info
current_time = time.time() current_time = time.time()
self._update_matched_node_info(req_id, match_block_node, self._update_matched_node_info(req_id, match_block_node, current_time)
current_time)
# 2. prepare cache # 2. prepare cache
gpu_recv_block_ids, gpu_extra_block_ids, = self._prepare_cache(req_id, \ (gpu_recv_block_ids, gpu_extra_block_ids,) = self._prepare_cache(
input_ids, block_size, block_num, match_gpu_block_ids, match_cpu_block_ids, swap_node_ids) req_id,
input_ids,
block_size,
block_num,
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
)
# update matched token num # update matched token num
matched_block_num = (gpu_match_token_num + cpu_match_token_num) matched_block_num = gpu_match_token_num + cpu_match_token_num
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
unique_block_ids = gpu_extra_block_ids unique_block_ids = gpu_extra_block_ids
dec_block_num = dec_token_num // block_size dec_block_num = dec_token_num // block_size
left_input_ids = input_ids[ left_input_ids = input_ids[matched_token_num_in_cpu_and_gpu:]  # tokens not in the prefix tree
matched_token_num_in_cpu_and_gpu:]  # tokens not in the prefix tree
gpu_build_path_block_ids = [] gpu_build_path_block_ids = []
gpu_build_path_block_ids = gpu_extra_block_ids gpu_build_path_block_ids = gpu_extra_block_ids
leaf_node = self.build_path(req_id, current_time, input_ids, leaf_node = self.build_path(
req_id,
current_time,
input_ids,
left_input_ids, left_input_ids,
gpu_build_path_block_ids, gpu_build_path_block_ids,
block_size, match_block_node, block_size,
dec_block_num) match_block_node,
dec_block_num,
)
self.req_leaf_map[req_id] = leaf_node self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id) self.leaf_req_map[leaf_node].add(req_id)
# 3. update metrics # 3. update metrics
@@ -482,17 +498,15 @@ class PrefixCacheManager:
gpu_match_token_num, gpu_match_token_num,
input_token_num, input_token_num,
) )
hit_info[ hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
"gpu_cache_blocks"] = gpu_match_token_num // block_size hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
hit_info[
"cpu_cache_blocks"] = cpu_match_token_num // block_size
self.metrics._update_history_hit_metrics() self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0: if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics() self.metrics.reset_metrics()
logger.info( logger.info(
f"request_block_ids: request block for req_id {req_id}: common_block_ids " f"request_block_ids: request block for req_id {req_id}: common_block_ids "
+ + f"{common_block_ids}, unique_block_ids {unique_block_ids}"
f"{common_block_ids}, unique_block_ids {unique_block_ids}") )
return common_block_ids, unique_block_ids, hit_info return common_block_ids, unique_block_ids, hit_info
except Exception as e: except Exception as e:
logger.error(f"request_block_ids: error: {type(e)} {e}") logger.error(f"request_block_ids: error: {type(e)} {e}")
@@ -523,25 +537,21 @@ class PrefixCacheManager:
node.decrement_shared_count() node.decrement_shared_count()
node = node.parent node = node.parent
logger.info( logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
f"release_block_ids: req_id {req_id} leaf_node {leaf_node}"
)
if leaf_node == self.radix_tree_root: if leaf_node == self.radix_tree_root:
self.recycle_gpu_blocks( self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id])
self.unfilled_req_block_map[req_id])
del self.unfilled_req_block_map[req_id] del self.unfilled_req_block_map[req_id]
return return
if leaf_node in self.gpu_lru_leaf_set: if leaf_node in self.gpu_lru_leaf_set:
return return
if (leaf_node.shared_count == 0 and leaf_node.is_gpu_leaf_node if leaf_node.shared_count == 0 and leaf_node.is_gpu_leaf_node and leaf_node.is_persistent is False:
and leaf_node.is_persistent is False):
self.gpu_lru_leaf_set.add(leaf_node) self.gpu_lru_leaf_set.add(leaf_node)
heapq.heappush(self.gpu_lru_leaf_heap, leaf_node) heapq.heappush(self.gpu_lru_leaf_heap, leaf_node)
logger.info( logger.info(
f"release_block_ids: req_id {req_id} has been finished, " + f"release_block_ids: req_id {req_id} has been finished, "
f"current gpu_lru_leaf_heap length {len(self.gpu_lru_leaf_heap)}" + f"current gpu_lru_leaf_heap length {len(self.gpu_lru_leaf_heap)}"
) )
return return
except Exception as e: except Exception as e:
@@ -563,8 +573,15 @@ class PrefixCacheManager:
node.reverved_dec_block_ids = [] node.reverved_dec_block_ids = []
self.recycle_gpu_blocks(node.block_id) self.recycle_gpu_blocks(node.block_id)
def _handle_free_gpu_node_with_cpu(self, node, hash_value_input_ids_map, \ def _handle_free_gpu_node_with_cpu(
hash_value_depth_map, need_recycle_gpu_block_ids, hash_value_gpu_block_ids_map, hash_value_swap_node_ids_map): self,
node,
hash_value_input_ids_map,
hash_value_depth_map,
need_recycle_gpu_block_ids,
hash_value_gpu_block_ids_map,
hash_value_swap_node_ids_map,
):
""" """
GPU node eviction in hierarchical cache layers GPU node eviction in hierarchical cache layers
""" """
@@ -573,14 +590,19 @@ class PrefixCacheManager:
node.reverved_dec_block_ids = [] node.reverved_dec_block_ids = []
need_recycle_gpu_block_ids.append(node.block_id) need_recycle_gpu_block_ids.append(node.block_id)
hash_value_gpu_block_ids_map[node.input_hash_value].append( hash_value_gpu_block_ids_map[node.input_hash_value].append(node.block_id)
node.block_id) hash_value_swap_node_ids_map[node.input_hash_value].append(node.node_id)
hash_value_swap_node_ids_map[node.input_hash_value].append(
node.node_id)
def _evict_cache_async(self, future, total_gpu_free_count, \ def _evict_cache_async(
hash_value_gpu_block_ids_map, hash_value_block_ids_map, \ self,
hash_value_swap_node_ids_map, hash_value_input_ids_map, hash_value_depth_map): future,
total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
hash_value_swap_node_ids_map,
hash_value_input_ids_map,
hash_value_depth_map,
):
""" """
evict cache async (GPU --> CPU) evict cache async (GPU --> CPU)
""" """
@@ -592,23 +614,21 @@ class PrefixCacheManager:
need_transfer_task_cpu_block_ids = [] need_transfer_task_cpu_block_ids = []
cpu_block_ids = self.allocate_cpu_blocks(total_gpu_free_count) cpu_block_ids = self.allocate_cpu_blocks(total_gpu_free_count)
for input_hash_value in hash_value_gpu_block_ids_map.keys(): for input_hash_value in hash_value_gpu_block_ids_map.keys():
need_transfer_task_gpu_block_ids.extend( need_transfer_task_gpu_block_ids.extend(reversed(hash_value_gpu_block_ids_map[input_hash_value]))
reversed(hash_value_gpu_block_ids_map[input_hash_value]))
all_allocated_cpu_block_ids = [] all_allocated_cpu_block_ids = []
for _ in reversed(hash_value_gpu_block_ids_map[input_hash_value]): for _ in reversed(hash_value_gpu_block_ids_map[input_hash_value]):
cpu_block_id_t = cpu_block_ids.pop(0) cpu_block_id_t = cpu_block_ids.pop(0)
all_allocated_cpu_block_ids.append(cpu_block_id_t) all_allocated_cpu_block_ids.append(cpu_block_id_t)
need_transfer_task_cpu_block_ids.append(cpu_block_id_t) need_transfer_task_cpu_block_ids.append(cpu_block_id_t)
swap_node_ids.extend( swap_node_ids.extend(reversed(hash_value_swap_node_ids_map[input_hash_value]))
reversed(hash_value_swap_node_ids_map[input_hash_value]))
logger.info( logger.info(
"free_block_ids_async: issue transfer task: " + "free_block_ids_async: issue transfer task: "
f"transfer_task_id {transfer_task_id}: " + + f"transfer_task_id {transfer_task_id}: "
f"swap_node_ids {swap_node_ids} need_transfer_task_gpu_block_ids " + f"swap_node_ids {swap_node_ids} need_transfer_task_gpu_block_ids "
+ + f"{need_transfer_task_gpu_block_ids}, need_transfer_task_cpu_block_ids "
f"{need_transfer_task_gpu_block_ids}, need_transfer_task_cpu_block_ids " + f"{need_transfer_task_cpu_block_ids}, CacheStatus.SWAP2CPU"
+ f"{need_transfer_task_cpu_block_ids}, CacheStatus.SWAP2CPU") )
self.issue_swap_task( self.issue_swap_task(
transfer_task_id, transfer_task_id,
swap_node_ids, swap_node_ids,
@@ -619,9 +639,8 @@ class PrefixCacheManager:
) )
logger.info( logger.info(
"free_block_ids_async: after free, " + "free_block_ids_async: after free, " + f"len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
f"len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}") )
return
def free_block_ids_async(self, need_block_num): def free_block_ids_async(self, need_block_num):
""" """
@@ -654,8 +673,10 @@ class PrefixCacheManager:
break break
node = heapq.heappop(self.gpu_lru_leaf_heap) node = heapq.heappop(self.gpu_lru_leaf_heap)
self.gpu_lru_leaf_set.remove(node) self.gpu_lru_leaf_set.remove(node)
if not self.cache_config.enable_hierarchical_cache or \ if (
self.cache_config.num_cpu_blocks < need_block_num: not self.cache_config.enable_hierarchical_cache
or self.cache_config.num_cpu_blocks < need_block_num
):
if node.shared_count == 0 and node.is_gpu_leaf_node: # recycle directly if node.shared_count == 0 and node.is_gpu_leaf_node: # recycle directly
self._handle_free_gpu_node_without_cpu(node) self._handle_free_gpu_node_without_cpu(node)
total_gpu_free_count += 1 total_gpu_free_count += 1
@@ -666,12 +687,13 @@ class PrefixCacheManager:
if not node.children: if not node.children:
if node in self.gpu_lru_leaf_set: if node in self.gpu_lru_leaf_set:
continue continue
if (node != self.radix_tree_root if (
node != self.radix_tree_root
and node.shared_count == 0 and node.shared_count == 0
and node.is_gpu_leaf_node and node.is_gpu_leaf_node
and node.is_persistent is False): and node.is_persistent is False
heapq.heappush(self.gpu_lru_leaf_heap, ):
node) heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node) self.gpu_lru_leaf_set.add(node)
else: else:
continue continue
@@ -680,18 +702,25 @@ class PrefixCacheManager:
node.cache_status = CacheStatus.SWAP2CPU node.cache_status = CacheStatus.SWAP2CPU
else: else:
continue continue
self._handle_free_gpu_node_with_cpu(node, hash_value_input_ids_map, \ self._handle_free_gpu_node_with_cpu(
hash_value_depth_map, need_recycle_gpu_block_ids, \ node,
hash_value_gpu_block_ids_map, hash_value_swap_node_ids_map) hash_value_input_ids_map,
hash_value_depth_map,
need_recycle_gpu_block_ids,
hash_value_gpu_block_ids_map,
hash_value_swap_node_ids_map,
)
total_gpu_free_count += 1 total_gpu_free_count += 1
node = node.parent node = node.parent
if node in self.gpu_lru_leaf_set: if node in self.gpu_lru_leaf_set:
continue continue
if (node != self.radix_tree_root if (
node != self.radix_tree_root
and node.shared_count == 0 and node.shared_count == 0
and node.is_gpu_leaf_node and node.is_gpu_leaf_node
and node.is_persistent is False): and node.is_persistent is False
):
heapq.heappush(self.gpu_lru_leaf_heap, node) heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node) self.gpu_lru_leaf_set.add(node)
@@ -702,12 +731,16 @@ class PrefixCacheManager:
cpu_free_count = total_gpu_free_count cpu_free_count = total_gpu_free_count
if cpu_free_count < need_block_num: if cpu_free_count < need_block_num:
cpu_free_count = need_block_num cpu_free_count = need_block_num
cpu_free_future = self.free_cpu_executor_pool.submit( cpu_free_future = self.free_cpu_executor_pool.submit(self.free_cpu_block_ids, cpu_free_count)
self.free_cpu_block_ids, cpu_free_count)
self.gpu_free_task_future = self.free_gpu_executor_pool.submit( self.gpu_free_task_future = self.free_gpu_executor_pool.submit(
self._evict_cache_async, cpu_free_future, total_gpu_free_count, \ self._evict_cache_async,
hash_value_gpu_block_ids_map, hash_value_block_ids_map, \ cpu_free_future,
hash_value_swap_node_ids_map, hash_value_input_ids_map, hash_value_depth_map total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
hash_value_swap_node_ids_map,
hash_value_input_ids_map,
hash_value_depth_map,
) )
else: else:
self.gpu_free_task_future = None self.gpu_free_task_future = None
@@ -724,10 +757,7 @@ class PrefixCacheManager:
Returns: Returns:
- freed_block_num: Number of CPU blocks successfully evicted - freed_block_num: Number of CPU blocks successfully evicted
""" """
hash_value_input_ids_map = {}
hash_value_block_ids_map = defaultdict(list) hash_value_block_ids_map = defaultdict(list)
hash_value_depth_map = {}
need_recycle_cpu_block_ids = []
total_cpu_free_count = 0 total_cpu_free_count = 0
with self.request_release_lock: with self.request_release_lock:
while True: while True:
@@ -739,13 +769,10 @@ class PrefixCacheManager:
node = heapq.heappop(self.cpu_lru_leaf_heap) node = heapq.heappop(self.cpu_lru_leaf_heap)
self.cpu_lru_leaf_set.remove(node) self.cpu_lru_leaf_set.remove(node)
tmp_block_ids = [] tmp_block_ids = []
if (node.shared_count == 0 if node.shared_count == 0 and node.cache_status == CacheStatus.CPU and node.is_cpu_leaf_node:
and node.cache_status == CacheStatus.CPU
and node.is_cpu_leaf_node):
self.recycle_cpu_blocks(node.block_id) self.recycle_cpu_blocks(node.block_id)
hash_value_block_ids_map[node.input_hash_value].extend( hash_value_block_ids_map[node.input_hash_value].extend(reversed(tmp_block_ids))
reversed(tmp_block_ids))
logger.info(f"free_cpu_block_ids: free node {node}") logger.info(f"free_cpu_block_ids: free node {node}")
self.node_id_pool.append(node.node_id) self.node_id_pool.append(node.node_id)
@@ -759,15 +786,17 @@ class PrefixCacheManager:
if not node.children: if not node.children:
if node in self.cpu_lru_leaf_set: if node in self.cpu_lru_leaf_set:
continue continue
if (node != self.radix_tree_root if (
node != self.radix_tree_root
and node.shared_count == 0 and node.shared_count == 0
and node.is_cpu_leaf_node and node.is_cpu_leaf_node
and node.cache_status == CacheStatus.CPU): and node.cache_status == CacheStatus.CPU
):
heapq.heappush(self.cpu_lru_leaf_heap, node) heapq.heappush(self.cpu_lru_leaf_heap, node)
self.cpu_lru_leaf_set.add(node) self.cpu_lru_leaf_set.add(node)
logger.info( logger.info(
"free_cpu_block_ids: after free, " + "free_cpu_block_ids: after free, " + f"len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
f"len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}") )
return total_cpu_free_count return total_cpu_free_count
def cal_block_hash(self, block): def cal_block_hash(self, block):
@@ -807,8 +836,7 @@ class PrefixCacheManager:
with self.cache_status_lock: with self.cache_status_lock:
while match_token_num < total_token_num: while match_token_num < total_token_num:
token_block = input_ids[match_token_num:match_token_num + token_block = input_ids[match_token_num : match_token_num + block_size]
block_size]
token_num = len(token_block) token_num = len(token_block)
if token_num != block_size: if token_num != block_size:
break break
@@ -817,11 +845,11 @@ class PrefixCacheManager:
child = current_match_node.children[hash_value] child = current_match_node.children[hash_value]
matche_nodes.append(child) matche_nodes.append(child)
match_node_ids.append(child.node_id) match_node_ids.append(child.node_id)
if (child in self.gpu_lru_leaf_set): if child in self.gpu_lru_leaf_set:
self.gpu_lru_leaf_set.remove(child) self.gpu_lru_leaf_set.remove(child)
self.gpu_lru_leaf_heap.remove(child) self.gpu_lru_leaf_heap.remove(child)
has_modified_gpu_lru_leaf_heap = True has_modified_gpu_lru_leaf_heap = True
elif (child in self.cpu_lru_leaf_set): elif child in self.cpu_lru_leaf_set:
self.cpu_lru_leaf_set.remove(child) self.cpu_lru_leaf_set.remove(child)
self.cpu_lru_leaf_heap.remove(child) self.cpu_lru_leaf_heap.remove(child)
has_modified_cpu_lru_leaf_heap = True has_modified_cpu_lru_leaf_heap = True
@@ -831,8 +859,9 @@ class PrefixCacheManager:
else: else:
if child.cache_status == CacheStatus.SWAP2CPU: if child.cache_status == CacheStatus.SWAP2CPU:
logger.info( logger.info(
f"match_block: req_id {req_id} matched node" + f"match_block: req_id {req_id} matched node"
f" {child.node_id} which is being SWAP2CPU") + f" {child.node_id} which is being SWAP2CPU"
)
child.cache_status = CacheStatus.GPU child.cache_status = CacheStatus.GPU
match_gpu_block_ids.append(child.block_id) match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size gpu_match_token_num += block_size
@@ -851,8 +880,7 @@ class PrefixCacheManager:
if has_modified_cpu_lru_leaf_heap: if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap) heapq.heapify(self.cpu_lru_leaf_heap)
logger.info( logger.info(f"match_block: req_id {req_id} matched nodes: {match_node_ids}")
f"match_block: req_id {req_id} matched nodes: {match_node_ids}")
return ( return (
match_gpu_block_ids, match_gpu_block_ids,
match_cpu_block_ids, match_cpu_block_ids,
@@ -873,9 +901,17 @@ class PrefixCacheManager:
node.req_id_set.add(req_id) node.req_id_set.add(req_id)
node = node.parent node = node.parent
def build_path(self, req_id, current_time, input_ids, left_input_ids, def build_path(
gpu_block_ids, block_size, last_node, self,
reverved_dec_block_num): req_id,
current_time,
input_ids,
left_input_ids,
gpu_block_ids,
block_size,
last_node,
reverved_dec_block_num,
):
""" """
Build path for blocks beyond the common prefix Build path for blocks beyond the common prefix
Parameters: Parameters:
@@ -915,7 +951,8 @@ class PrefixCacheManager:
allocated_block_id = gpu_block_ids.pop(0) allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop() node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id) unique_node_ids.append(node_id)
new_last_node = BlockNode(node_id, new_last_node = BlockNode(
node_id,
input_ids, input_ids,
input_hash_value, input_hash_value,
node.depth + 1, node.depth + 1,
@@ -925,7 +962,8 @@ class PrefixCacheManager:
current_time, current_time,
parent=node, parent=node,
shared_count=1, shared_count=1,
reverved_dec_block_ids=[]) reverved_dec_block_ids=[],
)
new_last_node.req_id_set.add(req_id) new_last_node.req_id_set.add(req_id)
self.node_map[node_id] = new_last_node self.node_map[node_id] = new_last_node
node.children[hash_value] = new_last_node node.children[hash_value] = new_last_node
@@ -939,46 +977,44 @@ class PrefixCacheManager:
self.unfilled_req_block_map[req_id] = reverved_dec_block_ids self.unfilled_req_block_map[req_id] = reverved_dec_block_ids
else: else:
new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids) new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
logger.info( logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {req_id}")
f"build_path: allocate unique node ids {unique_node_ids} for req_id {req_id}"
)
return new_last_node return new_last_node
def _handle_swap_result(self, swap_node_id, task_gpu_block_id, def _handle_swap_result(self, swap_node_id, task_gpu_block_id, task_cpu_block_id, event_type):
task_cpu_block_id, event_type):
""" """
handle swap result handle swap result
""" """
if swap_node_id is None: if swap_node_id is None:
return return
with self.cache_status_lock: with self.cache_status_lock:
if (event_type.value == CacheStatus.SWAP2CPU.value): if event_type.value == CacheStatus.SWAP2CPU.value:
gpu_block_id = task_gpu_block_id gpu_block_id = task_gpu_block_id
cpu_block_id = task_cpu_block_id cpu_block_id = task_cpu_block_id
node = self.node_map[swap_node_id] node = self.node_map[swap_node_id]
if node.cache_status.value == CacheStatus.GPU.value: if node.cache_status.value == CacheStatus.GPU.value:
logger.info( logger.info(
f"recv_data_transfer_result: node {node.node_id} " + f"recv_data_transfer_result: node {node.node_id} "
f"has been reused when SWAP2CPU, recycle cpu block id {cpu_block_id}" + f"has been reused when SWAP2CPU, recycle cpu block id {cpu_block_id}"
) )
self.recycle_cpu_blocks(cpu_block_id) self.recycle_cpu_blocks(cpu_block_id)
else: else:
node.cache_status = CacheStatus.CPU node.cache_status = CacheStatus.CPU
node.block_id = cpu_block_id node.block_id = cpu_block_id
if (node != self.radix_tree_root and node.shared_count == 0 if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_cpu_leaf_node and node.is_cpu_leaf_node
and node.cache_status == CacheStatus.CPU): and node.cache_status == CacheStatus.CPU
):
if node not in self.cpu_lru_leaf_set: if node not in self.cpu_lru_leaf_set:
heapq.heappush(self.cpu_lru_leaf_heap, node) heapq.heappush(self.cpu_lru_leaf_heap, node)
self.cpu_lru_leaf_set.add(node) self.cpu_lru_leaf_set.add(node)
self.recycle_gpu_blocks(gpu_block_id) self.recycle_gpu_blocks(gpu_block_id)
logger.info( logger.info(f"recv_data_transfer_result: after SWAP2CPU, node {node}")
f"recv_data_transfer_result: after SWAP2CPU, node {node}"
)
elif (event_type.value == CacheStatus.SWAP2GPU.value): elif event_type.value == CacheStatus.SWAP2GPU.value:
gpu_block_id = task_gpu_block_id gpu_block_id = task_gpu_block_id
cpu_block_id = task_cpu_block_id cpu_block_id = task_cpu_block_id
@@ -987,12 +1023,12 @@ class PrefixCacheManager:
node.block_id = gpu_block_id node.block_id = gpu_block_id
self.recycle_cpu_blocks(cpu_block_id) self.recycle_cpu_blocks(cpu_block_id)
logger.info( logger.info(f"recv_data_transfer_result: after SWAP2GPU, node {node}")
f"recv_data_transfer_result: after SWAP2GPU, node {node}")
else: else:
logger.warning( logger.warning(
f"recv_data_transfer_result: Get unexpected event type {event_type}" f"recv_data_transfer_result: Get unexpected event type {event_type}"
+ ", only SWAP2CPU and SWAP2GPU supported") + ", only SWAP2CPU and SWAP2GPU supported"
)
def recv_data_transfer_result(self): def recv_data_transfer_result(self):
""" """
@@ -1024,10 +1060,8 @@ class PrefixCacheManager:
self.task_swapping_event[transfer_task_id].set() self.task_swapping_event[transfer_task_id].set()
logger.info( logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: " f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+ + f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} " + f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
+
f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
) )
except Exception as e: except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}") logger.warning(f"recv_data_transfer_result: error: {e}")

View File

@@ -13,5 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
from .ipc_cache_transfer import IPCCommManager from .ipc_cache_transfer import IPCCommManager
from .rdma_cache_transfer import RDMACommManager from .rdma_cache_transfer import RDMACommManager
__all__ = ["IPCCommManager", "RDMACommManager"]

View File

@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
""" """
import os
import paddle import paddle
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import (
get_data_ptr_ipc, ipc_sent_key_value_cache_by_remote_ptr, get_data_ptr_ipc,
ipc_sent_key_value_cache_by_remote_ptr_block_sync) ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync,
)
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log") logger = get_logger("cache_messager", "cache_messager.log")
@@ -44,17 +44,13 @@ class IPCConnector:
self.rank_id = rank_id_ self.rank_id = rank_id_
self.local_gpu_id = int(local_gpu_id_) self.local_gpu_id = int(local_gpu_id_)
tmp = paddle.ones([1, 1]) tmp = paddle.ones([1, 1])
logger.info( logger.info(f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}")
f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}"
)
for layer_id in range(layer_num): for layer_id in range(layer_num):
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}" key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}" value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
self.remote_key_tensor_ptr_list.append( self.remote_key_tensor_ptr_list.append(get_data_ptr_ipc(tmp, key_unique_name))
get_data_ptr_ipc(tmp, key_unique_name)) self.remote_value_tensor_ptr_list.append(get_data_ptr_ipc(tmp, value_unique_name))
self.remote_value_tensor_ptr_list.append( self.write_stream = paddle.device.Stream(f"gpu:{self.local_gpu_id}")
get_data_ptr_ipc(tmp, value_unique_name))
self.write_stream = paddle.device.Stream(f'gpu:{self.local_gpu_id}')
self.finish_event = paddle.device.Event() self.finish_event = paddle.device.Event()
@@ -83,14 +79,11 @@ class IPCCommManager:
""" """
Connect to remote gpu. Connect to remote gpu.
""" """
logger.info( logger.info(f"{self.rank_id}: connect to remote_gpu_id:{remote_gpu_id_} {self.layer_num} {self.gpu_idx}")
f"{self.rank_id}: connect to remote_gpu_id:{remote_gpu_id_} {self.layer_num} {self.gpu_idx}"
)
if self.is_connected(remote_gpu_id_): if self.is_connected(remote_gpu_id_):
return True return True
else: else:
self.comm_map[remote_gpu_id_] = IPCConnector( self.comm_map[remote_gpu_id_] = IPCConnector(self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
return True return True
def is_connected(self, remote_gpu_id_=0): def is_connected(self, remote_gpu_id_=0):
@@ -102,8 +95,7 @@ class IPCCommManager:
else: else:
return False return False
def write_cache(self, ip, remote_gpu_id, local_block_ids, remote_block_ids, def write_cache(self, ip, remote_gpu_id, local_block_ids, remote_block_ids, layer_idx):
layer_idx):
""" """
Connect to remote gpu and write cache. Connect to remote gpu and write cache.
""" """
@@ -114,20 +106,26 @@ class IPCCommManager:
with paddle.device.stream_guard(comm.write_stream): with paddle.device.stream_guard(comm.write_stream):
ipc_sent_key_value_cache_by_remote_ptr( ipc_sent_key_value_cache_by_remote_ptr(
self.local_key_cache_tensor_list[layer_idx], self.local_key_cache_tensor_list[layer_idx],
self.local_value_cache_tensor_list[layer_idx], local_block_ids, self.local_value_cache_tensor_list[layer_idx],
remote_block_ids, comm.remote_key_tensor_ptr_list[layer_idx], local_block_ids,
comm.remote_value_tensor_ptr_list[layer_idx], block_num, remote_block_ids,
self.gpu_idx, comm.remote_gpu_id, comm.remote_key_tensor_ptr_list[layer_idx],
comm.write_stream.stream_base.cuda_stream) comm.remote_value_tensor_ptr_list[layer_idx],
block_num,
self.gpu_idx,
comm.remote_gpu_id,
comm.write_stream.stream_base.cuda_stream,
)
return 0 return 0
def write_block_by_sync(self, remote_gpu_id): def write_block_by_sync(self, remote_gpu_id):
""" """
check finish event and wait for it check finish event and wait for it
""" """
paddle.set_device(f'gpu:{self.gpu_idx}') paddle.set_device(f"gpu:{self.gpu_idx}")
comm = self.comm_map[remote_gpu_id] comm = self.comm_map[remote_gpu_id]
ipc_sent_key_value_cache_by_remote_ptr_block_sync( ipc_sent_key_value_cache_by_remote_ptr_block_sync(
self.local_key_cache_tensor_list[0], # tensor no use self.local_key_cache_tensor_list[0], # tensor no use
self.local_value_cache_tensor_list[0], # tensor no use self.local_value_cache_tensor_list[0], # tensor no use
comm.write_stream.stream_base.cuda_stream) comm.write_stream.stream_base.cuda_stream,
)
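Both `IPCCommManager.connect` above and the RDMA manager further down follow the same connect-once pattern: look the peer up in a local map or set and only create a new channel when it is absent. A minimal sketch of that pattern with hypothetical names (not the actual FastDeploy classes):

```python
class ConnectionCache:
    """Connect-once helper: reuse an existing connection, create it only when missing."""

    def __init__(self, factory):
        self._factory = factory   # callable that builds a new connection for a peer id
        self._conns = {}

    def is_connected(self, peer_id) -> bool:
        return peer_id in self._conns

    def connect(self, peer_id):
        if peer_id not in self._conns:
            self._conns[peer_id] = self._factory(peer_id)
        return self._conns[peer_id]


# Usage with a dummy factory; a real factory would open an IPC or RDMA channel.
cache = ConnectionCache(lambda peer: f"channel-to-{peer}")
assert cache.connect(1) is cache.connect(1)   # the second call reuses the cached channel
```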

View File

@@ -42,11 +42,13 @@ Bandwidth Saturation Capability: Under multi-threaded high-pressure scenarios, b
### Dependencies Installation ### Dependencies Installation
#### Python Packages #### Python Packages
```bash ```bash
pip install pyzmq pybind11[global] pip install pyzmq pybind11[global]
``` ```
#### System Libraries (Linux) #### System Libraries (Linux)
```bash ```bash
# Ubuntu/Debian # Ubuntu/Debian
sudo apt-get install -y libibverbs-dev librdmacm-dev sudo apt-get install -y libibverbs-dev librdmacm-dev
@@ -107,8 +109,6 @@ pip install dist/*.whl
|----------|---------|-------------| |----------|---------|-------------|
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs | | `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
# Set RDMA GID index # Set RDMA GID index
export KVCACHE_RDMA_GID_INDEX=3 export KVCACHE_RDMA_GID_INDEX=3
@@ -125,7 +125,6 @@ export KVCACHE_DEBUG=1
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
## Network configurations ## Network configurations
kvcache transfer is fully tested with RDMA over Converged Ethernet (RoCE) networks. However, it is theoretically compatible with Infiniband as well. kvcache transfer is fully tested with RDMA over Converged Ethernet (RoCE) networks. However, it is theoretically compatible with Infiniband as well.
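The variables documented above are ordinary process environment settings. As a small illustration, reading them from Python could look like the sketch below; the names come from this page and the defaults from its examples, while the exact place where the transfer layer consumes them is not part of this diff, so treat the read pattern as an assumption.

```python
import os

# Illustrative only: variable names are taken from the documentation above, defaults
# from its examples; how the transfer layer actually reads them is an assumption.
rdma_gid_index = int(os.getenv("KVCACHE_RDMA_GID_INDEX", "3"))
gdrcopy_flush = os.getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE", "false").lower() == "true"
debug_enabled = os.getenv("KVCACHE_DEBUG", "0") == "1"
debug_file = os.getenv("KVCACHE_DEBUG_FILE", "/var/log/kvcache_debug.log")
error_file = os.getenv("KVCACHE_ERROR_FILE", "/var/log/kvcache_error.log")
```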

View File

@@ -43,11 +43,13 @@
### Dependency Installation ### Dependency Installation
#### Python Packages #### Python Packages
```bash ```bash
pip install pyzmq pybind11[global] pip install pyzmq pybind11[global]
``` ```
#### System Libraries (Linux) #### System Libraries (Linux)
```bash ```bash
# Ubuntu/Debian # Ubuntu/Debian
sudo apt-get install -y libibverbs-dev librdmacm-dev sudo apt-get install -y libibverbs-dev librdmacm-dev
@@ -108,7 +110,6 @@ pip install dist/*.whl
|------|--------|------| |------|--------|------|
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs | | `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
# Set RDMA GID index # Set RDMA GID index
export KVCACHE_RDMA_GID_INDEX=3 export KVCACHE_RDMA_GID_INDEX=3
@@ -125,7 +126,6 @@ export KVCACHE_DEBUG=1
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
## Network Configuration ## Network Configuration
kvcache transfer has been fully tested on RDMA over Converged Ethernet (RoCE) networks; in theory it is also compatible with Infiniband. kvcache transfer has been fully tested on RDMA over Converged Ethernet (RoCE) networks; in theory it is also compatible with Infiniband.

View File

@@ -24,13 +24,24 @@ class RDMACommManager:
RDMACommManager to manage rdma communication RDMACommManager to manage rdma communication
""" """
def __init__(self, splitwise_role, rank, gpu_id, cache_k_ptr_list, \ def __init__(
cache_v_ptr_list, max_block_num, block_bytes, rdma_port): self,
splitwise_role,
rank,
gpu_id,
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
rdma_port,
):
try: try:
import rdma_comm import rdma_comm
except: except:
logger.error(f"The installation of the RDMA library failed." \ logger.error(
"Confirm whether your network card supports RDMA transmission.") "The installation of the RDMA library failed."
"Confirm whether your network card supports RDMA transmission."
)
return return
self.messager = rdma_comm.RDMACommunicator( self.messager = rdma_comm.RDMACommunicator(
splitwise_role, splitwise_role,
@@ -50,7 +61,7 @@ class RDMACommManager:
Connect to remote gpu and write cache. Connect to remote gpu and write cache.
""" """
assert self.splitwise_role == "prefill", "only prefill can call this method" assert self.splitwise_role == "prefill", "only prefill can call this method"
addr = f"{ip}:{str(port)}" addr = f"{ip}:{port!s}"
if addr in self.connected_rdma: if addr in self.connected_rdma:
return True return True
ret = self.messager.is_connected(ip, str(port)) ret = self.messager.is_connected(ip, str(port))
@@ -59,18 +70,13 @@ class RDMACommManager:
return True return True
ret = self.messager.connect(ip, str(port)) ret = self.messager.connect(ip, str(port))
logger.info( logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
f"connect to remote rdma address {ip}:{port} status is {ret}")
if ret == 0: if ret == 0:
self.connected_rdma.add(addr) self.connected_rdma.add(addr)
return ret == 0 return ret == 0
def write_cache(self, ip, port, local_block_ids, remote_block_ids, def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):
layer_idx):
""" """
Connect to remote gpu and write cache. Connect to remote gpu and write cache.
""" """
return self.messager.write_cache(ip, str(port), local_block_ids, return self.messager.write_cache(ip, str(port), local_block_ids, remote_block_ids, layer_idx)
remote_block_ids, layer_idx)

View File

@@ -24,12 +24,12 @@ from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.transformers.configuration_utils import PretrainedConfig
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import \ from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
QuantConfigBase
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
logger = get_logger("config", "config.log") logger = get_logger("config", "config.log")
class MoEPhase(Enum): class MoEPhase(Enum):
""" """
The generation phase of the moe. The generation phase of the moe.
@@ -38,13 +38,14 @@ class MoEPhase(Enum):
PREFILL = 1 PREFILL = 1
DECODER = 2 DECODER = 2
class ErnieArchitectures: class ErnieArchitectures:
"""Helper class for ERNIE architecture check.""" """Helper class for ERNIE architecture check."""
ARCHITECTURES = { ARCHITECTURES = {
"Ernie4_5_ForCausalLM", "Ernie4_5_ForCausalLM",
"Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeForCausalLM",
"Ernie4_5_VLMoeForConditionalGeneration" "Ernie4_5_VLMoeForConditionalGeneration",
} }
@classmethod @classmethod
@@ -57,6 +58,7 @@ class ErnieArchitectures:
"""Check if the given architecture is an ERNIE architecture.""" """Check if the given architecture is an ERNIE architecture."""
return architecture in cls.ARCHITECTURES return architecture in cls.ARCHITECTURES
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
"rope_theta": 10000.0, "rope_theta": 10000.0,
"num_key_value_heads": -1, "num_key_value_heads": -1,
@@ -81,6 +83,7 @@ class ModelConfig:
""" """
The configuration class to store the configuration of a `LLM`. The configuration class to store the configuration of a `LLM`.
""" """
def __init__( def __init__(
self, self,
args, args,
@@ -134,6 +137,7 @@ class ModelConfig:
class ParallelConfig: class ParallelConfig:
"""Configuration for the distributed execution.""" """Configuration for the distributed execution."""
def __init__( def __init__(
self, self,
args, args,
@@ -213,10 +217,8 @@ class ParallelConfig:
self.enable_custom_all_reduce: bool = False self.enable_custom_all_reduce: bool = False
# pd_disaggregation # pd_disaggregation
use_pd_disaggregation: int = int( use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
os.getenv("FLAGS_use_pd_disaggregation", 0)) use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
use_pd_disaggregation_per_chunk: int = int(
os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
if use_pd_disaggregation_per_chunk: if use_pd_disaggregation_per_chunk:
self.pd_disaggregation_mode = "per_chunk" self.pd_disaggregation_mode = "per_chunk"
elif use_pd_disaggregation: elif use_pd_disaggregation:
@@ -224,10 +226,12 @@ class ParallelConfig:
else: else:
self.pd_disaggregation_mode = "None" self.pd_disaggregation_mode = "None"
class SpeculativeConfig: class SpeculativeConfig:
""" """
Configuration for speculative decoding. Configuration for speculative decoding.
""" """
def __init__( def __init__(
self, self,
args, args,
@@ -263,20 +267,24 @@ class SpeculativeConfig:
# TODO(YuanRisheng): The name of the server args is different from the name of the SpeculativeConfig. # TODO(YuanRisheng): The name of the server args is different from the name of the SpeculativeConfig.
# We temporarily add the name map here and will delete it in the future. # We temporarily add the name map here and will delete it in the future.
name_map = {"speculative_method": "method", name_map = {
"speculative_method": "method",
"speculative_max_draft_token_num": "num_speculative_tokens", "speculative_max_draft_token_num": "num_speculative_tokens",
"speculative_model_name_or_path": "model_name_or_path", "speculative_model_name_or_path": "model_name_or_path",
"speculative_model_quantization": "quantization", "speculative_model_quantization": "quantization",
"speculative_benchmark_mode": "benchmark_mode"} "speculative_benchmark_mode": "benchmark_mode",
}
for key, value in args.items(): for key, value in args.items():
if key in name_map.keys() and hasattr(self, name_map[key]): if key in name_map.keys() and hasattr(self, name_map[key]):
setattr(self, name_map[key], value) setattr(self, name_map[key], value)
class DeviceConfig: class DeviceConfig:
""" """
Configuration for device settings. Configuration for device settings.
""" """
def __init__( def __init__(
self, self,
args, args,
@@ -286,6 +294,7 @@ class DeviceConfig:
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
@dataclass @dataclass
class GraphOptimizationConfig: class GraphOptimizationConfig:
""" """
@@ -336,15 +345,10 @@ class GraphOptimizationConfig:
full_cuda_graph: bool = True full_cuda_graph: bool = True
max_capture_size: int = field(default=None, init=False) # type: ignore max_capture_size: int = field(default=None, init=False) # type: ignore
batch_size_to_captured_size: dict[int, batch_size_to_captured_size: dict[int, int] = field(default=None, init=False) # type: ignore
int] = field(default=None,
init=False) # type: ignore
# CINN Config ... # CINN Config ...
def init_with_cudagrpah_size( def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
self,
max_num_seqs:int = 0
) -> None:
""" """
Initialize cuda graph capture sizes and Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size pre-compute the mapping from batch size to padded graph size
@@ -353,32 +357,28 @@ class GraphOptimizationConfig:
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs] self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
dedup_sizes = list(set(self.cudagraph_capture_sizes)) dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes): if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(("cudagraph sizes specified by model runner" logger.info(
" %s is overridden by config %s"), ("cudagraph sizes specified by model runner" " %s is overridden by config %s"),
self.cudagraph_capture_sizes, dedup_sizes) self.cudagraph_capture_sizes,
dedup_sizes,
)
self.cudagraph_capture_sizes = dedup_sizes self.cudagraph_capture_sizes = dedup_sizes
# Sort to make sure cudagraph capture sizes are in descending order # Sort to make sure cudagraph capture sizes are in descending order
self.cudagraph_capture_sizes.sort(reverse=True) self.cudagraph_capture_sizes.sort(reverse=True)
self.max_capture_size = self.cudagraph_capture_sizes[ self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
0] if self.cudagraph_capture_sizes else 0
# Pre-compute the mapping from batch size to padded graph size # Pre-compute the mapping from batch size to padded graph size
self.batch_size_to_captured_size = {} self.batch_size_to_captured_size = {}
for end, start in zip(self.cudagraph_capture_sizes, for end, start in zip(self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]):
self.cudagraph_capture_sizes[1:] + [0]):
for bs in range(start, end): for bs in range(start, end):
if bs == start: if bs == start:
self.batch_size_to_captured_size[bs] = start self.batch_size_to_captured_size[bs] = start
else: else:
self.batch_size_to_captured_size[bs] = end self.batch_size_to_captured_size[bs] = end
self.batch_size_to_captured_size[ self.batch_size_to_captured_size[self.max_capture_size] = self.max_capture_size
self.max_capture_size] = self.max_capture_size
def _set_cudagraph_sizes( def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
self,
max_num_seqs:int = 0
):
""" """
Calculate a series of candidate capture batch sizes, Calculate a series of candidate capture batch sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input. and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -405,24 +405,28 @@ class LoadConfig:
- 'ipc_snapshot': Load from disk snapshot of IPC weights - 'ipc_snapshot': Load from disk snapshot of IPC weights
- None: No dynamic loading - None: No dynamic loading
""" """
def __init__( def __init__(
self, self,
args, args,
): ):
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1 self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
self.dynamic_load_weight: bool = False self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal['ipc', 'ipc_snapshot']] = None self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
for key, value in args.items(): for key, value in args.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
class LoRAConfig: class LoRAConfig:
"""LoRA Config""" """LoRA Config"""
pass pass
class KVCacheConfig: class KVCacheConfig:
"""KV Cache Config""" """KV Cache Config"""
cache_quant_dtype: str = "none" cache_quant_dtype: str = "none"
@@ -430,6 +434,7 @@ class DecodingConfig:
""" """
Configuration for decoding Configuration for decoding
""" """
def __init__( def __init__(
self, self,
args, args,
@@ -439,26 +444,24 @@ class DecodingConfig:
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
@dataclass @dataclass
class FDConfig: class FDConfig:
""" """
The configuration class which contains all fastdeploy-related configuration. This The configuration class which contains all fastdeploy-related configuration. This
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
""" """
model_config: ModelConfig = field(default=None, init=True) # type: ignore model_config: ModelConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default=None, init=True) parallel_config: ParallelConfig = field(default=None, init=True)
speculative_config: SpeculativeConfig = field(default=None, speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore
init=True) # type: ignore device_config: DeviceConfig = field(default=None, init=True) # type: ignore
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None graph_opt_config: Optional[GraphOptimizationConfig] = None
decoding_config: DecodingConfig = field(default=None, decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore
init=True) # type: ignore kv_cache_config: KVCacheConfig = field(default=None, init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None,
init=True) # type: ignore
def __post_init__(self): def __post_init__(self):
# Initialize cuda graph capture list # Initialize cuda graph capture list

View File

@@ -22,8 +22,6 @@ model_name_or_path = "./models/llama-7b"
# Hyperparameter settings # Hyperparameter settings
sampling_params = SamplingParams(temperature=0.1, max_tokens=30) sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
llm = LLM(model=model_name_or_path, tensor_parallel_size=1) llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
output = llm.generate(prompts="who are you", output = llm.generate(prompts="who are you", use_tqdm=True, sampling_params=sampling_params)
use_tqdm=True,
sampling_params=sampling_params)
print(output) print(output)

View File

@@ -14,19 +14,15 @@
# limitations under the License. # limitations under the License.
""" """
import time
import os
import multiprocessing import multiprocessing
import os
import time
from fastdeploy.entrypoints.llm import LLM from fastdeploy.entrypoints.llm import LLM
from fastdeploy.engine.sampling_params import SamplingParams
model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle" model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle"
def start_decode(model_name_or_path): def start_decode(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "1" os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["FD_LOG_DIR"] = "log_decode" os.environ["FD_LOG_DIR"] = "log_decode"
@@ -36,14 +32,15 @@ def start_decode(model_name_or_path):
splitwise_role="decode", splitwise_role="decode",
engine_worker_queue_port=6678, engine_worker_queue_port=6678,
innode_prefill_ports=[6676], innode_prefill_ports=[6676],
cache_queue_port=55668 cache_queue_port=55668,
) )
return llm_decode return llm_decode
def start_prefill(model_name_or_path): def start_prefill(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["FD_LOG_DIR"] = "log_prefill" os.environ["FD_LOG_DIR"] = "log_prefill"
llm_prefill = LLM( LLM(
model=model_name_or_path, model=model_name_or_path,
tensor_parallel_size=1, tensor_parallel_size=1,
splitwise_role="prefill", splitwise_role="prefill",
@@ -53,16 +50,14 @@ def start_prefill(model_name_or_path):
def main(): def main():
prefill = multiprocessing.Process( prefill = multiprocessing.Process(target=start_prefill, args=(model_name_or_path,))
prefill.start()
target=start_prefill,
args=(model_name_or_path,)).start()
time.sleep(10) time.sleep(10)
llm_decode = start_decode(model_name_or_path) llm_decode = start_decode(model_name_or_path)
output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True) output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True)
print(output) print(output)
decode.join() prefill.join()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" """
import openai import openai
ip = "0.0.0.0" ip = "0.0.0.0"
@@ -42,7 +41,7 @@ response = client.completions.create(
) )
for chunk in response: for chunk in response:
print(chunk.choices[0].text, end='') print(chunk.choices[0].text, end="")
print("\n") print("\n")
# Chat completion # Chat completion
@@ -78,5 +77,5 @@ response = client.chat.completions.create(
for chunk in response: for chunk in response:
if chunk.choices[0].delta is not None: if chunk.choices[0].delta is not None:
print(chunk.choices[0].delta.content, end='') print(chunk.choices[0].delta.content, end="")
print("\n") print("\n")

View File

@@ -14,14 +14,12 @@
# limitations under the License. # limitations under the License.
""" """
import openai import openai
print("hello") print("hello")
ip = "0.0.0.0" ip = "0.0.0.0"
service_http_port = "9809" service_http_port = "9809"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
api_key="EMPTY_API_KEY")
print("world") print("world")
# Non-streaming chat # Non-streaming chat
@@ -30,23 +28,21 @@ response = client.chat.completions.create(
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": "You are a helpful AI assistant." "content": "You are a helpful AI assistant.",
}, # system is optional, not required }, # system is optional, not required
{ {
"role": "role": "user",
"user", "content": [
"content": [{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
"https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0", "detail": "high",
"detail": "high" },
} },
}, { {"type": "text", "text": "请描述图片内容"},
"type": "text", ],
"text": "请描述图片内容" },
}]
}
], ],
temperature=1, temperature=1,
max_tokens=53, max_tokens=53,
@@ -60,30 +56,25 @@ response = client.chat.completions.create(
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": "You are a helpful AI assistant." "content": "You are a helpful AI assistant.",
}, # system is optional, not required }, # system is optional, not required
{ {"role": "user", "content": "List 3 countries and their capitals."},
"role": "user",
"content": "List 3 countries and their capitals."
},
{ {
"role": "assistant", "role": "assistant",
"content": "China(Beijing), France(Paris), Australia(Canberra)." "content": "China(Beijing), France(Paris), Australia(Canberra).",
}, },
{ {
"role": "role": "user",
"user", "content": [
"content": [{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
"https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0", "detail": "high",
"detail": "high" },
} },
}, { {"type": "text", "text": "请描述图片内容"},
"type": "text", ],
"text": "请描述图片内容"
}]
}, },
], ],
temperature=1, temperature=1,
@@ -94,5 +85,5 @@ for chunk in response:
if chunk.choices[0].delta is not None: if chunk.choices[0].delta is not None:
# print(chunk.choices[0].delta, end='') # print(chunk.choices[0].delta, end='')
# print("\n") # print("\n")
print(chunk.choices[0].delta.content, end='') print(chunk.choices[0].delta.content, end="")
print(response) print(response)

View File

@@ -17,21 +17,28 @@
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddle.distributed import fleet from paddle.distributed import fleet
from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
_TP_AR = None _TP_AR = None
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024): def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
hcg = fleet.get_hybrid_communicate_group() hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group() model_parallel_group = hcg.get_model_parallel_group()
global _TP_AR global _TP_AR
if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda(): if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda():
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes) _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
try: try:
@paddle.jit.marker.unified @paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor: def tensor_model_parallel_all_reduce(
input_: paddle.Tensor,
) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
global _TP_AR global _TP_AR
if _TP_AR is not None and _TP_AR.should_custom_ar(input_): if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
@@ -42,5 +49,6 @@ try:
dist.all_reduce(input_, group=mp_group) dist.all_reduce(input_, group=mp_group)
else: else:
dist.all_reduce(input_) dist.all_reduce(input_)
except: except:
tensor_model_parallel_all_reduce = None tensor_model_parallel_all_reduce = None

View File

@@ -41,7 +41,7 @@ def find_loaded_library(lib_name) -> Optional[str]:
the file `/proc/self/maps` contains the memory maps of the process, which includes the the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of the shared libraries loaded by the process. We can use this file to find the path of the
a loaded library. a loaded library.
""" # noqa """
found = False found = False
with open("/proc/self/maps") as f: with open("/proc/self/maps") as f:
for line in f: for line in f:
@@ -73,18 +73,40 @@ class CudaRTLibrary:
# const char* cudaGetErrorString ( cudaError_t error ) # const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size ) # cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), Function(
"cudaMalloc",
cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
),
# cudaError_t cudaFree ( void* devPtr ) # cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]), Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function( Function(
"cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint] "cudaMemset",
cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t],
),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind )
Function(
"cudaMemcpy",
cudaError_t,
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr )
Function(
"cudaIpcGetMemHandle",
cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags )
Function(
"cudaIpcOpenMemHandle",
cudaError_t,
[
ctypes.POINTER(ctypes.c_void_p),
cudaIpcMemHandle_t,
ctypes.c_uint,
],
), ),
] ]
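The `CudaRTLibrary` wrapper above declares each CUDA runtime symbol together with its return type and argument types before calling it through ctypes. The same binding pattern in miniature, using libc's `strlen` instead of the CUDA runtime so it runs without a GPU (POSIX only, purely illustrative):

```python
import ctypes

# Bind a C function with explicit restype/argtypes, as CudaRTLibrary does for cudaMalloc etc.
libc = ctypes.CDLL(None)               # symbols already loaded into this process (POSIX)
strlen = libc.strlen
strlen.restype = ctypes.c_size_t
strlen.argtypes = [ctypes.c_char_p]

assert strlen(b"cuda") == 4
```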

View File

@@ -13,26 +13,26 @@
# limitations under the License. # limitations under the License.
from contextlib import contextmanager
import atexit import atexit
import ctypes import ctypes
from contextlib import contextmanager
from typing import List, Optional from typing import List, Optional
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddle.distributed.communication.group import Group from paddle.distributed.communication.group import Group
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import (
all_reduce, all_reduce,
dispose, dispose,
get_graph_buffer_ipc_meta,
init_custom_all_reduce, init_custom_all_reduce,
meta_size, meta_size,
register_buffer, register_buffer,
get_graph_buffer_ipc_meta,
register_graph_buffers, register_graph_buffers,
) )
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
try: try:
meta_size() meta_size()
custom_ar = True custom_ar = True
@@ -147,7 +147,12 @@ class CustomAllreduce:
return inp_size < self.max_size return inp_size < self.max_size
return False return False
def all_reduce(self, inp: paddle.Tensor, out: paddle.Tensor = None, registered: bool = False): def all_reduce(
self,
inp: paddle.Tensor,
out: paddle.Tensor = None,
registered: bool = False,
):
"""Performs an out-of-place all reduce. """Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already If registered is True, this assumes inp's pointer is already
@@ -179,16 +184,12 @@ class CustomAllreduce:
def register_graph_buffers(self): def register_graph_buffers(self):
handle, offset = get_graph_buffer_ipc_meta(self._ptr) handle, offset = get_graph_buffer_ipc_meta(self._ptr)
all_data = [[None, None] all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset] all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group)) ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks): for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i], dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists. # Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore handles = [d[0] for d in all_data] # type: ignore

View File

@@ -13,15 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
import json import json
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fastdeploy.engine.config import (CacheConfig, Config, from fastdeploy.engine.config import (
GraphOptimizationConfig, ModelConfig, CacheConfig,
ParallelConfig, SpeculativeConfig, Config,
TaskOption) GraphOptimizationConfig,
ModelConfig,
ParallelConfig,
SpeculativeConfig,
TaskOption,
)
from fastdeploy.scheduler.config import SchedulerConfig from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser from fastdeploy.utils import FlexibleArgumentParser
@@ -323,365 +329,429 @@ class EngineArgs:
""" """
# Model parameters group # Model parameters group
model_group = parser.add_argument_group("Model Configuration") model_group = parser.add_argument_group("Model Configuration")
model_group.add_argument("--model", model_group.add_argument(
"--model",
type=str, type=str,
default=EngineArgs.model, default=EngineArgs.model,
help="Model name or path to be used.") help="Model name or path to be used.",
model_group.add_argument("--model-config-name", )
model_group.add_argument(
"--model-config-name",
type=nullable_str, type=nullable_str,
default=EngineArgs.model_config_name, default=EngineArgs.model_config_name,
help="The model configuration file name.") help="The model configuration file name.",
)
model_group.add_argument( model_group.add_argument(
"--tokenizer", "--tokenizer",
type=nullable_str, type=nullable_str,
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help= help="Tokenizer name or path (defaults to model path if not specified).",
"Tokenizer name or path (defaults to model path if not specified)."
) )
model_group.add_argument( model_group.add_argument(
"--max-model-len", "--max-model-len",
type=int, type=int,
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
help="Maximum context length supported by the model.") help="Maximum context length supported by the model.",
)
model_group.add_argument( model_group.add_argument(
"--block-size", "--block-size",
type=int, type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
help="Number of tokens processed in one block.") help="Number of tokens processed in one block.",
model_group.add_argument("--task", )
model_group.add_argument(
"--task",
type=str, type=str,
default=EngineArgs.task, default=EngineArgs.task,
help="Task to be executed by the model.") help="Task to be executed by the model.",
)
model_group.add_argument( model_group.add_argument(
"--use-warmup", "--use-warmup",
type=int, type=int,
default=EngineArgs.use_warmup, default=EngineArgs.use_warmup,
help="Flag to indicate whether to use warm-up before inference.") help="Flag to indicate whether to use warm-up before inference.",
)
model_group.add_argument( model_group.add_argument(
"--limit-mm-per-prompt", "--limit-mm-per-prompt",
default=EngineArgs.limit_mm_per_prompt, default=EngineArgs.limit_mm_per_prompt,
type=json.loads, type=json.loads,
help="Limitation of numbers of multi-modal data.") help="Limitation of numbers of multi-modal data.",
)
model_group.add_argument( model_group.add_argument(
"--mm-processor-kwargs", "--mm-processor-kwargs",
default=EngineArgs.mm_processor_kwargs, default=EngineArgs.mm_processor_kwargs,
type=json.loads, type=json.loads,
help="Additional keyword arguments for the multi-modal processor.") help="Additional keyword arguments for the multi-modal processor.",
model_group.add_argument("--enable-mm", )
action='store_true', model_group.add_argument(
"--enable-mm",
action="store_true",
default=EngineArgs.enable_mm, default=EngineArgs.enable_mm,
help="Flag to enable multi-modal model.") help="Flag to enable multi-modal model.",
model_group.add_argument("--reasoning-parser", )
model_group.add_argument(
"--reasoning-parser",
type=str, type=str,
default=EngineArgs.reasoning_parser, default=EngineArgs.reasoning_parser,
help="Flag specifies the reasoning parser to use for extracting "\ help="Flag specifies the reasoning parser to use for extracting "
"reasoning content from the model output") "reasoning content from the model output",
)
model_group.add_argument( model_group.add_argument(
"--speculative-config", "--speculative-config",
type=json.loads, type=json.loads,
default=EngineArgs.speculative_config, default=EngineArgs.speculative_config,
help="Configuration for speculative execution.") help="Configuration for speculative execution.",
)
model_group.add_argument( model_group.add_argument(
"--dynamic-load-weight", "--dynamic-load-weight",
action='store_true', action="store_true",
default=EngineArgs.dynamic_load_weight, default=EngineArgs.dynamic_load_weight,
help="Flag to indicate whether to load weight dynamically.") help="Flag to indicate whether to load weight dynamically.",
)
model_group.add_argument( model_group.add_argument(
"--load-strategy", "--load-strategy",
type=str, type=str,
default=EngineArgs.load_strategy, default=EngineArgs.load_strategy,
help="Flag to dynamic load strategy.") help="Flag to dynamic load strategy.",
model_group.add_argument("--engine-worker-queue-port", )
model_group.add_argument(
"--engine-worker-queue-port",
type=int, type=int,
default=EngineArgs.engine_worker_queue_port, default=EngineArgs.engine_worker_queue_port,
help="port for engine worker queue") help="port for engine worker queue",
model_group.add_argument("--quantization", )
model_group.add_argument(
"--quantization",
type=str, type=str,
default=EngineArgs.quantization, default=EngineArgs.quantization,
help="Quantization name for the model, currentlly support " \ help="Quantization name for the model, currentlly support "
"'wint8', 'wint4'," \ "'wint8', 'wint4',"
"default is None. The priority of this configuration "\ "default is None. The priority of this configuration "
"is lower than that of the config file. " \ "is lower than that of the config file. "
"More complex quantization methods need to be configured via the config file.") "More complex quantization methods need to be configured via the config file.",
model_group.add_argument("--use-cudagraph", )
action='store_true', model_group.add_argument(
"--use-cudagraph",
action="store_true",
default=EngineArgs.use_cudagraph, default=EngineArgs.use_cudagraph,
help="Flags to enable cuda graph.") help="Flags to enable cuda graph.",
model_group.add_argument("--graph-optimization-config", )
model_group.add_argument(
"--graph-optimization-config",
type=json.loads, type=json.loads,
default=EngineArgs.graph_optimization_config, default=EngineArgs.graph_optimization_config,
help="") help="",
model_group.add_argument("--guided-decoding-backend", )
model_group.add_argument(
"--guided-decoding-backend",
type=str, type=str,
default=EngineArgs.guided_decoding_backend, default=EngineArgs.guided_decoding_backend,
help="Guided Decoding Backend") help="Guided Decoding Backend",
)
model_group.add_argument( model_group.add_argument(
"--guided-decoding-disable-any-whitespace", "--guided-decoding-disable-any-whitespace",
type=str, type=str,
default=EngineArgs.guided_decoding_disable_any_whitespace, default=EngineArgs.guided_decoding_disable_any_whitespace,
help= help="Disabled any whitespaces when using guided decoding backend XGrammar.",
"Disabled any whitespaces when using guided decoding backend XGrammar."
) )
model_group.add_argument("--enable-logprob", model_group.add_argument(
"--enable-logprob",
action="store_true", action="store_true",
default=EngineArgs.enable_logprob, default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities." help="Enable output of token-level log probabilities.",
) )
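Several of the flags above (--limit-mm-per-prompt, --mm-processor-kwargs, --speculative-config, --graph-optimization-config) take JSON strings and rely on type=json.loads to turn them into dicts before they reach EngineArgs. A minimal, stdlib-only sketch of that behaviour; the flag values here are illustrative, not defaults from this repository:

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--limit-mm-per-prompt", type=json.loads, default=None)
parser.add_argument("--speculative-config", type=json.loads, default=None)

# Hypothetical command line, parsed the same way the engine flags are.
args = parser.parse_args(
    [
        "--limit-mm-per-prompt", '{"image": 2, "video": 1}',
        "--speculative-config", '{"method": "mtp", "num_speculative_tokens": 1}',
    ]
)
print(args.limit_mm_per_prompt["image"])   # 2
print(args.speculative_config["method"])   # mtp

Because the conversion happens at argument-parsing time, malformed JSON fails fast with an argparse error instead of surfacing later inside the engine.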
# Parallel processing parameters group # Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration") parallel_group = parser.add_argument_group("Parallel Configuration")
parallel_group.add_argument("--tensor-parallel-size", parallel_group.add_argument(
"--tensor-parallel-size",
"-tp", "-tp",
type=int, type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help="Degree of tensor parallelism.") help="Degree of tensor parallelism.",
parallel_group.add_argument("--enable-custom-all-reduce", )
action='store_true', parallel_group.add_argument(
"--enable-custom-all-reduce",
action="store_true",
default=EngineArgs.enable_custom_all_reduce, default=EngineArgs.enable_custom_all_reduce,
help="Flag to enable custom all-reduce.") help="Flag to enable custom all-reduce.",
)
parallel_group.add_argument( parallel_group.add_argument(
"--max-num-seqs", "--max-num-seqs",
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help="Maximum number of sequences per iteration.") help="Maximum number of sequences per iteration.",
)
parallel_group.add_argument( parallel_group.add_argument(
"--num-gpu-blocks-override", "--num-gpu-blocks-override",
type=int, type=int,
default=EngineArgs.num_gpu_blocks_override, default=EngineArgs.num_gpu_blocks_override,
help="Override for the number of GPU blocks.") help="Override for the number of GPU blocks.",
)
parallel_group.add_argument( parallel_group.add_argument(
"--max-num-batched-tokens", "--max-num-batched-tokens",
type=int, type=int,
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
help="Maximum number of tokens to batch together.") help="Maximum number of tokens to batch together.",
)
parallel_group.add_argument( parallel_group.add_argument(
"--gpu-memory-utilization", "--gpu-memory-utilization",
type=float, type=float,
default=EngineArgs.gpu_memory_utilization, default=EngineArgs.gpu_memory_utilization,
help="Fraction of GPU memory to be utilized.") help="Fraction of GPU memory to be utilized.",
)
parallel_group.add_argument("--data-parallel-size", parallel_group.add_argument(
"--data-parallel-size",
type=int, type=int,
default=EngineArgs.data_parallel_size, default=EngineArgs.data_parallel_size,
help="Degree of data parallelism.") help="Degree of data parallelism.",
parallel_group.add_argument("--enable-expert-parallel", )
action='store_true', parallel_group.add_argument(
"--enable-expert-parallel",
action="store_true",
default=EngineArgs.enable_expert_parallel, default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.") help="Enable expert parallelism.",
)
# CacheConfig parameters group # CacheConfig parameters group
cache_group = parser.add_argument_group("Cache Configuration") cache_group = parser.add_argument_group("Cache Configuration")
cache_group.add_argument("--kv-cache-ratio", cache_group.add_argument(
"--kv-cache-ratio",
type=float, type=float,
default=EngineArgs.kv_cache_ratio, default=EngineArgs.kv_cache_ratio,
help="Ratio of tokens to process in a block.") help="Ratio of tokens to process in a block.",
)
cache_group.add_argument( cache_group.add_argument(
"--swap-space", "--swap-space",
type=float, type=float,
default=EngineArgs.swap_space, default=EngineArgs.swap_space,
help="The amount of CPU memory to offload to.") help="The amount of CPU memory to offload to.",
)
cache_group.add_argument("--cache-queue-port", cache_group.add_argument(
"--cache-queue-port",
type=int, type=int,
default=EngineArgs.cache_queue_port, default=EngineArgs.cache_queue_port,
help="port for cache queue") help="port for cache queue",
cache_group.add_argument("--static-decode-blocks", )
cache_group.add_argument(
"--static-decode-blocks",
type=int, type=int,
default=EngineArgs.static_decode_blocks, default=EngineArgs.static_decode_blocks,
help="Static decoding blocks num.") help="Static decoding blocks num.",
)
# Cluster system parameters group # Cluster system parameters group
system_group = parser.add_argument_group("System Configuration") system_group = parser.add_argument_group("System Configuration")
system_group.add_argument( system_group.add_argument(
"--dist-init-ip", "--dist-init-ip",
default=EngineArgs.dist_init_ip, default=EngineArgs.dist_init_ip,
help= help="IP addresses of master node.",
"IP addresses of master node.") )
system_group.add_argument( system_group.add_argument(
"--nnodes", "--nnodes",
type=int, type=int,
default=EngineArgs.nnodes, default=EngineArgs.nnodes,
help= help="The number of all nodes.",
"The number of all nodes.") )
system_group.add_argument( system_group.add_argument(
"--node-rank", "--node-rank",
type=int, type=int,
default=EngineArgs.node_rank, default=EngineArgs.node_rank,
help= help="node rank id (range [0, nnodes)).",
"node rank id (range [0, nnodes)).") )
# Performance tuning parameters group # Performance tuning parameters group
perf_group = parser.add_argument_group("Performance Tuning") perf_group = parser.add_argument_group("Performance Tuning")
perf_group.add_argument("--enable-prefix-caching", perf_group.add_argument(
action='store_true', "--enable-prefix-caching",
action="store_true",
default=EngineArgs.enable_prefix_caching, default=EngineArgs.enable_prefix_caching,
help="Flag to enable prefix caching.") help="Flag to enable prefix caching.",
)
perf_group.add_argument("--splitwise-role", perf_group.add_argument(
"--splitwise-role",
type=str, type=str,
default=EngineArgs.splitwise_role, default=EngineArgs.splitwise_role,
help="Role of splitwise. Default is \ help="Role of splitwise. Default is \
'mixed'. (prefill, decode, mixed)") 'mixed'. (prefill, decode, mixed)",
)
perf_group.add_argument("--innode-prefill-ports", perf_group.add_argument(
"--innode-prefill-ports",
type=lambda s: s.split(",") if s else None, type=lambda s: s.split(",") if s else None,
default=EngineArgs.innode_prefill_ports, default=EngineArgs.innode_prefill_ports,
help="port for innode prefill") help="port for innode prefill",
)
perf_group.add_argument("--enable-chunked-prefill", perf_group.add_argument(
action='store_true', "--enable-chunked-prefill",
action="store_true",
default=EngineArgs.enable_chunked_prefill, default=EngineArgs.enable_chunked_prefill,
help="Flag to enable chunked prefill.") help="Flag to enable chunked prefill.",
perf_group.add_argument("--max-num-partial-prefills", )
perf_group.add_argument(
"--max-num-partial-prefills",
type=int, type=int,
default=EngineArgs.max_num_partial_prefills, default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, Maximum number \ help="For chunked prefill, Maximum number \
of concurrent partial prefill requests.") of concurrent partial prefill requests.",
)
perf_group.add_argument( perf_group.add_argument(
"--max-long-partial-prefills", "--max-long-partial-prefills",
type=int, type=int,
default=EngineArgs.max_long_partial_prefills, default=EngineArgs.max_long_partial_prefills,
help= help=(
("For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold" "For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold"
"that will be prefilled concurrently.")) "that will be prefilled concurrently."
),
)
perf_group.add_argument( perf_group.add_argument(
"--long-prefill-token-threshold", "--long-prefill-token-threshold",
type=int, type=int,
default=EngineArgs.long_prefill_token_threshold, default=EngineArgs.long_prefill_token_threshold,
help=("For chunked prefill, the threshold number of" help=("For chunked prefill, the threshold number of" " tokens for a prompt to be considered long."),
" tokens for a prompt to be considered long.")) )
perf_group.add_argument( perf_group.add_argument(
"--cache-transfer-protocol", "--cache-transfer-protocol",
type=str, type=str,
default=EngineArgs.cache_transfer_protocol, default=EngineArgs.cache_transfer_protocol,
help="support protocol list, comma separated, default is ipc") help="support protocol list, comma separated, default is ipc",
)
perf_group.add_argument("--pd-comm-port", perf_group.add_argument(
"--pd-comm-port",
type=lambda s: s.split(",") if s else None, type=lambda s: s.split(",") if s else None,
default=EngineArgs.pd_comm_port, default=EngineArgs.pd_comm_port,
help="port for splitwise communication.") help="port for splitwise communication.",
)
perf_group.add_argument("--rdma-comm-ports", perf_group.add_argument(
"--rdma-comm-ports",
type=lambda s: s.split(",") if s else None, type=lambda s: s.split(",") if s else None,
default=EngineArgs.rdma_comm_ports, default=EngineArgs.rdma_comm_ports,
help="ports for rdma communication.") help="ports for rdma communication.",
)
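The port flags above (--innode-prefill-ports, --pd-comm-port, --rdma-comm-ports) all share the same inline converter, lambda s: s.split(",") if s else None. A small stand-alone sketch of what it produces, using made-up port numbers:

import argparse

def split_csv(s):
    # Same behaviour as the inline lambda above: an empty value yields None.
    return s.split(",") if s else None

p = argparse.ArgumentParser()
p.add_argument("--rdma-comm-ports", type=split_csv, default=None)

print(p.parse_args(["--rdma-comm-ports", "7100,7101"]).rdma_comm_ports)  # ['7100', '7101']
print(p.parse_args([]).rdma_comm_ports)                                  # None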
# Scheduler parameters group # Scheduler parameters group
scheduler_group = parser.add_argument_group("Scheduler") scheduler_group = parser.add_argument_group("Scheduler")
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-name", "--scheduler-name",
default=EngineArgs.scheduler_name, default=EngineArgs.scheduler_name,
help= help=f"Scheduler name to be used. Default is {EngineArgs.scheduler_name}. (local,global)",
f"Scheduler name to be used. Default is {EngineArgs.scheduler_name}. (local,global)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-max-size", "--scheduler-max-size",
type=int, type=int,
default=EngineArgs.scheduler_max_size, default=EngineArgs.scheduler_max_size,
help= help=f"Size of scheduler. Default is {EngineArgs.scheduler_max_size}. (Local)",
f"Size of scheduler. Default is {EngineArgs.scheduler_max_size}. (Local)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-ttl", "--scheduler-ttl",
type=int, type=int,
default=EngineArgs.scheduler_ttl, default=EngineArgs.scheduler_ttl,
help= help=f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)",
f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-host", "--scheduler-host",
default=EngineArgs.scheduler_host, default=EngineArgs.scheduler_host,
help= help=f"Host address of redis. Default is {EngineArgs.scheduler_host}. (global)",
f"Host address of redis. Default is {EngineArgs.scheduler_host}. (global)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-port", "--scheduler-port",
type=int, type=int,
default=EngineArgs.scheduler_port, default=EngineArgs.scheduler_port,
help= help=f"Port of redis. Default is {EngineArgs.scheduler_port}. (global)",
f"Port of redis. Default is {EngineArgs.scheduler_port}. (global)") )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-db", "--scheduler-db",
type=int, type=int,
default=EngineArgs.scheduler_db, default=EngineArgs.scheduler_db,
help=f"DB of redis. Default is {EngineArgs.scheduler_db}. (global)" help=f"DB of redis. Default is {EngineArgs.scheduler_db}. (global)",
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-password", "--scheduler-password",
default=EngineArgs.scheduler_password, default=EngineArgs.scheduler_password,
help= help=f"Password of redis. Default is {EngineArgs.scheduler_password}. (global)",
f"Password of redis. Default is {EngineArgs.scheduler_password}. (global)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-topic", "--scheduler-topic",
default=EngineArgs.scheduler_topic, default=EngineArgs.scheduler_topic,
help= help=f"Topic of scheduler. Defaule is {EngineArgs.scheduler_topic}. (global)",
f"Topic of scheduler. Defaule is {EngineArgs.scheduler_topic}. (global)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-min-load-score", "--scheduler-min-load-score",
type=float, type=float,
default=EngineArgs.scheduler_min_load_score, default=EngineArgs.scheduler_min_load_score,
help= help=f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)",
f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)"
) )
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-load-shards-num", "--scheduler-load-shards-num",
type=int, type=int,
default=EngineArgs.scheduler_load_shards_num, default=EngineArgs.scheduler_load_shards_num,
help=("Number of shards for load balancing table. Default is " help=(
f"{EngineArgs.scheduler_load_shards_num} (global)")) "Number of shards for load balancing table. Default is "
f"{EngineArgs.scheduler_load_shards_num} (global)"
),
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-sync-period", "--scheduler-sync-period",
type=int, type=int,
default=EngineArgs.scheduler_sync_period, default=EngineArgs.scheduler_sync_period,
help=f"SplitWise Use, node load sync period, " help=f"SplitWise Use, node load sync period, "
f"Default is {EngineArgs.scheduler_sync_period}ms. (global)") f"Default is {EngineArgs.scheduler_sync_period}ms. (global)",
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-expire-period", "--scheduler-expire-period",
type=int, type=int,
default=EngineArgs.scheduler_expire_period, default=EngineArgs.scheduler_expire_period,
help=f"SplitWise Use, node will not be scheduled after " help=f"SplitWise Use, node will not be scheduled after "
f"expire-period ms not sync load, Default is " f"expire-period ms not sync load, Default is "
f"{EngineArgs.scheduler_expire_period}ms. (global)") f"{EngineArgs.scheduler_expire_period}ms. (global)",
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-release-load-expire-period", "--scheduler-release-load-expire-period",
type=int, type=int,
default=EngineArgs.scheduler_release_load_expire_period, default=EngineArgs.scheduler_release_load_expire_period,
help=f"SplitWise Use, scheduler will release req load after " help=f"SplitWise Use, scheduler will release req load after "
f"expire period(s). Default is " f"expire period(s). Default is "
f"{EngineArgs.scheduler_release_load_expire_period}. (global)") f"{EngineArgs.scheduler_release_load_expire_period}. (global)",
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-reader-parallel", "--scheduler-reader-parallel",
type=int, type=int,
default=EngineArgs.scheduler_reader_parallel, default=EngineArgs.scheduler_reader_parallel,
help=f"SplitWise Use, Results Reader Sync Parallel, " help=f"SplitWise Use, Results Reader Sync Parallel, "
f"Default is {EngineArgs.scheduler_reader_parallel}. (global)") f"Default is {EngineArgs.scheduler_reader_parallel}. (global)",
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-writer-parallel", "--scheduler-writer-parallel",
type=int, type=int,
default=EngineArgs.scheduler_writer_parallel, default=EngineArgs.scheduler_writer_parallel,
help=f"SplitWise Use, Results Writer Sync Parallel, " help=f"SplitWise Use, Results Writer Sync Parallel, "
f"Default is {EngineArgs.scheduler_writer_parallel}. (global)") f"Default is {EngineArgs.scheduler_writer_parallel}. (global)",
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-reader-batch-size", "--scheduler-reader-batch-size",
type=int, type=int,
default=EngineArgs.scheduler_reader_batch_size, default=EngineArgs.scheduler_reader_batch_size,
help=f"SplitWise Use, Results Reader Batch Size, " help=f"SplitWise Use, Results Reader Batch Size, "
f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)") f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)",
)
scheduler_group.add_argument( scheduler_group.add_argument(
"--scheduler-writer-batch-size", "--scheduler-writer-batch-size",
type=int, type=int,
default=EngineArgs.scheduler_writer_batch_size, default=EngineArgs.scheduler_writer_batch_size,
help=f"SplitWise Use, Results Writer Batch Size, " help=f"SplitWise Use, Results Writer Batch Size, "
f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)") f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)",
)
return parser return parser
@@ -690,21 +760,19 @@ class EngineArgs:
""" """
Create an instance of EngineArgs from command line arguments. Create an instance of EngineArgs from command line arguments.
""" """
return cls( return cls(**{field.name: getattr(args, field.name) for field in dataclass_fields(cls)})
**{
field.name: getattr(args, field.name)
for field in dataclass_fields(cls)
})
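from_cli_args above simply copies one argparse attribute per dataclass field. A self-contained sketch of the same pattern, with a hypothetical two-field stand-in for EngineArgs:

import argparse
from dataclasses import dataclass, fields as dataclass_fields

@dataclass
class ToyArgs:  # hypothetical stand-in for EngineArgs
    model: str = ""
    tensor_parallel_size: int = 1

parser = argparse.ArgumentParser()
parser.add_argument("--model", default=ToyArgs.model)
parser.add_argument("--tensor-parallel-size", type=int, default=ToyArgs.tensor_parallel_size)

ns = parser.parse_args(["--model", "./my_model", "--tensor-parallel-size", "2"])
# Same pattern as from_cli_args: one dataclass field per argparse attribute.
toy = ToyArgs(**{f.name: getattr(ns, f.name) for f in dataclass_fields(ToyArgs)})
print(toy)  # ToyArgs(model='./my_model', tensor_parallel_size=2)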
def create_model_config(self) -> ModelConfig: def create_model_config(self) -> ModelConfig:
""" """
Create and return a ModelConfig object based on the current settings. Create and return a ModelConfig object based on the current settings.
""" """
return ModelConfig(model_name_or_path=self.model, return ModelConfig(
model_name_or_path=self.model,
config_json_file=self.model_config_name, config_json_file=self.model_config_name,
quantization=self.quantization, quantization=self.quantization,
dynamic_load_weight=self.dynamic_load_weight, dynamic_load_weight=self.dynamic_load_weight,
load_strategy=self.load_strategy) load_strategy=self.load_strategy,
)
def create_cache_config(self, model_cfg) -> CacheConfig: def create_cache_config(self, model_cfg) -> CacheConfig:
""" """
@@ -728,8 +796,7 @@ class EngineArgs:
) )
def create_speculative_config(self) -> SpeculativeConfig: def create_speculative_config(self) -> SpeculativeConfig:
""" """ """
"""
if self.speculative_config is not None: if self.speculative_config is not None:
return SpeculativeConfig(**self.speculative_config) return SpeculativeConfig(**self.speculative_config)
else: else:
@@ -742,9 +809,11 @@ class EngineArgs:
prefix = "scheduler_" prefix = "scheduler_"
prefix_len = len(prefix) prefix_len = len(prefix)
extra_params = [ extra_params = [
"max_model_len", "enable_chunked_prefill", "max_model_len",
"max_num_partial_prefills", "max_long_partial_prefills", "enable_chunked_prefill",
"long_prefill_token_threshold" "max_num_partial_prefills",
"max_long_partial_prefills",
"long_prefill_token_threshold",
] ]
all = asdict(self) all = asdict(self)
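Only part of create_scheduler_config is visible in this hunk. Assuming the scheduler_ prefix is stripped before the values are handed to SchedulerConfig, the collection step would look roughly like this stand-alone sketch (Toy and its fields are hypothetical):

from dataclasses import asdict, dataclass

@dataclass
class Toy:  # hypothetical stand-in with a few scheduler_* fields
    scheduler_name: str = "local"
    scheduler_ttl: int = 900
    max_model_len: int = 8192

prefix = "scheduler_"
extra_params = ["max_model_len"]

all_fields = asdict(Toy())
# Strip the prefix from scheduler_* fields, then carry over the shared extras.
params = {k[len(prefix):]: v for k, v in all_fields.items() if k.startswith(prefix)}
params.update({k: all_fields[k] for k in extra_params})
print(params)  # {'name': 'local', 'ttl': 900, 'max_model_len': 8192}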
@@ -765,7 +834,7 @@ class EngineArgs:
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
data_parallel_size=self.data_parallel_size, data_parallel_size=self.data_parallel_size,
enable_custom_all_reduce=self.enable_custom_all_reduce enable_custom_all_reduce=self.enable_custom_all_reduce,
) )
def create_graph_optimization_config(self) -> GraphOptimizationConfig: def create_graph_optimization_config(self) -> GraphOptimizationConfig:
@@ -782,8 +851,7 @@ class EngineArgs:
Create and return a Config object based on the current settings. Create and return a Config object based on the current settings.
""" """
model_cfg = self.create_model_config() model_cfg = self.create_model_config()
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
'tensor_parallel_size'):
self.tensor_parallel_size = model_cfg.tensor_parallel_size self.tensor_parallel_size = model_cfg.tensor_parallel_size
if self.max_num_batched_tokens is None: if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill: if self.enable_chunked_prefill:
@@ -795,11 +863,11 @@ class EngineArgs:
graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
assert not (self.use_cudagraph and self.enable_prefix_caching), \ assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"
"Prefix caching cannot be used with CUDA graph"
assert not (self.tensor_parallel_size<=1 and self.enable_custom_all_reduce), \ assert not (
"enable_custom_all_reduce must be used with tensor_parallel_size>1" self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
return Config( return Config(
model_name_or_path=self.model, model_name_or_path=self.model,
View File
@@ -23,8 +23,14 @@ from typing import Any, Dict, List, Literal, Optional
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip, from fastdeploy.utils import (
is_port_available, get_random_port, llm_logger) ceil_div,
check_unified_ckpt,
get_host_ip,
get_random_port,
is_port_available,
llm_logger,
)
TaskOption = Literal["generate"] TaskOption = Literal["generate"]
@@ -39,13 +45,15 @@ class ModelConfig:
model_name_or_path (str): Name or path of the model. model_name_or_path (str): Name or path of the model.
""" """
def __init__(self, def __init__(
self,
model_name_or_path: str, model_name_or_path: str,
config_json_file: str = "config.json", config_json_file: str = "config.json",
dynamic_load_weight: bool = False, dynamic_load_weight: bool = False,
load_strategy: str = "ipc_snapshot", load_strategy: str = "ipc_snapshot",
quantization: str = None, quantization: str = None,
download_dir: Optional[str] = None): download_dir: Optional[str] = None,
):
""" """
Initialize the ModelConfig class. Initialize the ModelConfig class.
@@ -64,11 +72,9 @@ class ModelConfig:
if os.path.isfile(model_name_or_path): if os.path.isfile(model_name_or_path):
try: try:
from paddleformers.transformers import AutoConfig from paddleformers.transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name_or_path) config = AutoConfig.from_pretrained(model_name_or_path)
config_dict = { config_dict = {k: v for k, v in vars(config).items() if not k.startswith("_")}
k: v
for k, v in vars(config).items() if not k.startswith('_')
}
for key, value in config_dict.items(): for key, value in config_dict.items():
setattr(self, key, value) setattr(self, key, value)
except Exception: except Exception:
@@ -115,8 +121,7 @@ class ModelConfig:
if not hasattr(self, "mla_use_absorb"): if not hasattr(self, "mla_use_absorb"):
self.mla_use_absorb = False self.mla_use_absorb = False
if not hasattr(self, "head_dim"): if not hasattr(self, "head_dim"):
assert hasattr(self, "hidden_size") and hasattr( assert hasattr(self, "hidden_size") and hasattr(self, "num_attention_heads")
self, "num_attention_heads")
self.head_dim = self.hidden_size // self.num_attention_heads self.head_dim = self.hidden_size // self.num_attention_heads
def read_from_env(self): def read_from_env(self):
@@ -132,11 +137,9 @@ class ModelConfig:
if not hasattr(self, key.lower()): if not hasattr(self, key.lower()):
if os.getenv(key, None): if os.getenv(key, None):
value = eval(os.getenv(key)) value = eval(os.getenv(key))
llm_logger.info( llm_logger.info(f"Get parameter `{key}` = {value} from environment.")
f"Get parameter `{key}` = {value} from environment.")
else: else:
llm_logger.info( llm_logger.info(f"Parameter `{key}` will use default value {value}.")
f"Parameter `{key}` will use default value {value}.")
setattr(self, key.lower(), value) setattr(self, key.lower(), value)
reset_config_value("COMPRESSION_RATIO", 1.0) reset_config_value("COMPRESSION_RATIO", 1.0)
@@ -153,8 +156,7 @@ class ModelConfig:
llm_logger.info("Model Configuration Information :") llm_logger.info("Model Configuration Information :")
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info( llm_logger.info("=============================================================")
"=============================================================")
class CacheConfig: class CacheConfig:
@@ -211,8 +213,7 @@ class CacheConfig:
self.enc_dec_block_num = enc_dec_block_num self.enc_dec_block_num = enc_dec_block_num
self.cache_dtype = cache_dtype self.cache_dtype = cache_dtype
if hasattr(model_cfg, "quantization_config"): if hasattr(model_cfg, "quantization_config"):
self.cache_dtype = model_cfg.quantization_config.get( self.cache_dtype = model_cfg.quantization_config.get("kv_cache_quant_type", cache_dtype)
"kv_cache_quant_type", cache_dtype)
self.enable_chunked_prefill = enable_chunked_prefill self.enable_chunked_prefill = enable_chunked_prefill
self.rdma_comm_ports = rdma_comm_ports self.rdma_comm_ports = rdma_comm_ports
@@ -220,7 +221,7 @@ class CacheConfig:
self.pd_comm_port = pd_comm_port self.pd_comm_port = pd_comm_port
if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str): if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str):
self.rdma_comm_ports = rdma_comm_ports.split(',') self.rdma_comm_ports = rdma_comm_ports.split(",")
if pd_comm_port is not None and isinstance(pd_comm_port, str): if pd_comm_port is not None and isinstance(pd_comm_port, str):
self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")] self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")]
@@ -236,41 +237,39 @@ class CacheConfig:
self.cache_queue_port = cache_queue_port self.cache_queue_port = cache_queue_port
self.swap_space = swap_space self.swap_space = swap_space
if (hasattr(self.model_cfg, "num_key_value_heads") if (
hasattr(self.model_cfg, "num_key_value_heads")
and hasattr(self.model_cfg, "num_key_value_heads") and hasattr(self.model_cfg, "num_key_value_heads")
and self.model_cfg.num_key_value_heads is not None and self.model_cfg.num_key_value_heads is not None
and int(self.model_cfg.num_key_value_heads) > 0): and int(self.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(self.model_cfg.num_key_value_heads) kv_num_head = int(self.model_cfg.num_key_value_heads)
else: else:
kv_num_head = self.model_cfg.num_attention_heads kv_num_head = self.model_cfg.num_attention_heads
self.model_cfg.kv_num_head = kv_num_head self.model_cfg.kv_num_head = kv_num_head
# TODO check name # TODO check name
if "int4" in self.cache_dtype.lower( if "int4" in self.cache_dtype.lower() or "float4" in self.cache_dtype.lower():
) or "float4" in self.cache_dtype.lower():
byte_size = 0.5 byte_size = 0.5
self.cache_dtype = "uint8" self.cache_dtype = "uint8"
elif "int8" in self.cache_dtype.lower( elif "int8" in self.cache_dtype.lower() or "float8" in self.cache_dtype.lower():
) or "float8" in self.cache_dtype.lower():
self.cache_dtype = "uint8" self.cache_dtype = "uint8"
byte_size = 1 byte_size = 1
else: else:
byte_size = 2 byte_size = 2
self.each_token_cache_space = int( self.each_token_cache_space = int(
self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * byte_size
byte_size) )
self.bytes_per_block = int(self.each_token_cache_space * self.bytes_per_block = int(self.each_token_cache_space * self.block_size)
self.block_size)
self.bytes_per_layer_per_block = int( self.bytes_per_layer_per_block = int(
self.block_size * self.model_cfg.kv_num_head * self.block_size * self.model_cfg.kv_num_head * self.model_cfg.head_dim // tensor_parallel_size * byte_size
self.model_cfg.head_dim // tensor_parallel_size * byte_size) )
if self.swap_space is None: if self.swap_space is None:
self.num_cpu_blocks = 0 self.num_cpu_blocks = 0
else: else:
self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
self.bytes_per_block)
self._verify_args() self._verify_args()
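To make the cache-sizing arithmetic above concrete, a worked example under assumed values (32 layers, 8 KV heads, head_dim 128, a 2-byte cache dtype, block_size 64, tensor_parallel_size 1, swap_space 4 GB); none of these numbers come from this diff:

num_layers, kv_num_head, head_dim = 32, 8, 128
byte_size, block_size, tensor_parallel_size = 2, 64, 1
swap_space_gb = 4

each_token_cache_space = int(num_layers * kv_num_head * head_dim * byte_size)  # 65536 bytes per token
bytes_per_block = int(each_token_cache_space * block_size)                     # 4194304 bytes (4 MiB)
bytes_per_layer_per_block = int(block_size * kv_num_head * head_dim // tensor_parallel_size * byte_size)  # 131072
num_cpu_blocks = int(swap_space_gb * 1024**3 / bytes_per_block)                # 1024 CPU blocks

print(each_token_cache_space, bytes_per_block, bytes_per_layer_per_block, num_cpu_blocks)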
def metrics_info(self): def metrics_info(self):
@@ -279,12 +278,9 @@ class CacheConfig:
def _verify_args(self): def _verify_args(self):
if self.gpu_memory_utilization > 1.0: if self.gpu_memory_utilization > 1.0:
raise ValueError( raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.kv_cache_ratio > 1.0: if self.kv_cache_ratio > 1.0:
raise ValueError("KV cache ratio must be less than 1.0. Got " raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
f"{self.kv_cache_ratio}.")
def postprocess(self, num_total_tokens, number_of_tasks): def postprocess(self, num_total_tokens, number_of_tasks):
""" """
@@ -293,27 +289,24 @@ class CacheConfig:
self.dec_token_num = self.enc_dec_block_num * self.block_size self.dec_token_num = self.enc_dec_block_num * self.block_size
if self.num_gpu_blocks_override is not None: if self.num_gpu_blocks_override is not None:
self.total_block_num = self.num_gpu_blocks_override self.total_block_num = self.num_gpu_blocks_override
self.prefill_kvcache_block_num = int(self.total_block_num * self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
self.kv_cache_ratio)
else: else:
length = num_total_tokens // number_of_tasks length = num_total_tokens // number_of_tasks
block_num = (length + self.block_size - 1 + block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
self.dec_token_num) // self.block_size
self.total_block_num = block_num * number_of_tasks self.total_block_num = block_num * number_of_tasks
self.prefill_kvcache_block_num = self.total_block_num self.prefill_kvcache_block_num = self.total_block_num
llm_logger.info( llm_logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
f"Doing profile, the total_block_num:{self.total_block_num}")
def reset(self, num_gpu_blocks): def reset(self, num_gpu_blocks):
""" """
reset gpu block number reset gpu block number
""" """
self.total_block_num = num_gpu_blocks self.total_block_num = num_gpu_blocks
self.prefill_kvcache_block_num = int(self.total_block_num * self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
self.kv_cache_ratio)
llm_logger.info( llm_logger.info(
(f"Reset block num, the total_block_num:{self.total_block_num}," f"Reset block num, the total_block_num:{self.total_block_num},"
f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}")) f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
)
def print(self): def print(self):
""" """
@@ -323,8 +316,7 @@ class CacheConfig:
llm_logger.info("Cache Configuration Information :") llm_logger.info("Cache Configuration Information :")
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info( llm_logger.info("=============================================================")
"=============================================================")
class SpeculativeConfig: class SpeculativeConfig:
@@ -340,14 +332,16 @@ class SpeculativeConfig:
benchmark_mode (bool): Whether to use benchmark mode. benchmark_mode (bool): Whether to use benchmark mode.
""" """
def __init__(self, def __init__(
self,
method: Optional[str] = None, method: Optional[str] = None,
num_speculative_tokens: Optional[int] = 1, num_speculative_tokens: Optional[int] = 1,
model: Optional[str] = None, model: Optional[str] = None,
quantization: Optional[str] = "WINT8", quantization: Optional[str] = "WINT8",
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
benchmark_mode: bool = False, benchmark_mode: bool = False,
**kwargs): **kwargs,
):
self.model_name_or_path = model self.model_name_or_path = model
self.method = method self.method = method
self.num_speculative_tokens = num_speculative_tokens self.num_speculative_tokens = num_speculative_tokens
@@ -381,8 +375,7 @@ class SpeculativeConfig:
self.config_path = os.path.join(self.model_name_or_path, "config.json") self.config_path = os.path.join(self.model_name_or_path, "config.json")
if os.path.exists(self.config_path): if os.path.exists(self.config_path):
self.model_config = json.load( self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
open(self.config_path, 'r', encoding='utf-8'))
def reset(self): def reset(self):
""" """
@@ -414,10 +407,7 @@ class SpeculativeConfig:
""" """
Convert speculative_config to json string. Convert speculative_config to json string.
""" """
return json.dumps({ return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
key: value
for key, value in self.__dict__.items() if value is not None
})
def print(self): def print(self):
""" """
@@ -427,8 +417,7 @@ class SpeculativeConfig:
llm_logger.info("Speculative Decoding Configuration Information :") llm_logger.info("Speculative Decoding Configuration Information :")
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info( llm_logger.info("=============================================================")
"=============================================================")
def __str__(self) -> str: def __str__(self) -> str:
return self.to_json_string() return self.to_json_string()
@@ -440,7 +429,7 @@ class GraphOptimizationConfig:
graph_opt_level: Optional[int] = 0, graph_opt_level: Optional[int] = 0,
use_cudagraph: Optional[bool] = None, use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None, cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs **kwargs,
): ):
""" """
Graph Optimization Configuration class. Graph Optimization Configuration class.
@@ -460,10 +449,7 @@ class GraphOptimizationConfig:
""" """
Convert graph_optimization_config to json string. Convert graph_optimization_config to json string.
""" """
return json.dumps({ return json.dumps({key: value for key, value in self.__dict__.items()})
key: value
for key, value in self.__dict__.items()
})
def __str__(self) -> str: def __str__(self) -> str:
return self.to_json_string() return self.to_json_string()
@@ -473,17 +459,25 @@ class GraphOptimizationConfig:
graph_opt_level: Optional[int] = None, graph_opt_level: Optional[int] = None,
use_cudagraph: Optional[bool] = None, use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None, cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs **kwargs,
) -> None: ) -> None:
"""Check the legality of parameters passed in from the command line""" """Check the legality of parameters passed in from the command line"""
if graph_opt_level is not None: if graph_opt_level is not None:
assert graph_opt_level in [0, 1, 2], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2." assert graph_opt_level in [
0,
1,
2,
], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if use_cudagraph is not None: if use_cudagraph is not None:
assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must be bool." assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must be bool."
if cudagraph_capture_sizes is not None: if cudagraph_capture_sizes is not None:
assert type(cudagraph_capture_sizes) is list, "In graph optimization config, type of cudagraph_capture_sizes must be list." assert (
assert len(cudagraph_capture_sizes) > 0, "In graph optimization config, when CUDA graph is enabled, the capture sizes must not be an empty list." type(cudagraph_capture_sizes) is list
), "In graph optimization config, type of cudagraph_capture_sizes must be list."
assert (
len(cudagraph_capture_sizes) > 0
), "In graph optimization config, when CUDA graph is enabled, the capture sizes must not be an empty list."
for key, value in kwargs.items(): for key, value in kwargs.items():
raise ValueError(f"Invalid --graph-optimization-config parameter {key}") raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
@@ -499,9 +493,12 @@ class GraphOptimizationConfig:
else: else:
# User both set '--use-cudagraph' and '--graph-optimization-config' # User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True: if self.use_cudagraph is False and argument is True:
raise ValueError("Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.") raise ValueError(
"Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
)
argument = self.use_cudagraph argument = self.use_cudagraph
class ParallelConfig: class ParallelConfig:
""" """
Configuration for parallelism. Configuration for parallelism.
@@ -544,8 +541,7 @@ class ParallelConfig:
llm_logger.info("Parallel Configuration Information :") llm_logger.info("Parallel Configuration Information :")
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info( llm_logger.info("=============================================================")
"=============================================================")
@dataclass @dataclass
@@ -560,6 +556,7 @@ class CommitConfig:
cuda_version: CUDA version string cuda_version: CUDA version string
compiler_version: CXX compiler version string compiler_version: CXX compiler version string
""" """
fastdeploy_commit: str = "" fastdeploy_commit: str = ""
paddle_version: str = "" paddle_version: str = ""
paddle_commit: str = "" paddle_commit: str = ""
@@ -573,7 +570,7 @@ class CommitConfig:
def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"): def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
"""Internal method to load version info from file""" """Internal method to load version info from file"""
try: try:
with open(file_path, 'r') as f: with open(file_path, "r") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if line.startswith("fastdeploy GIT COMMIT ID:"): if line.startswith("fastdeploy GIT COMMIT ID:"):
@@ -589,7 +586,7 @@ class CommitConfig:
except FileNotFoundError: except FileNotFoundError:
llm_logger.info(f"Warning: Version file not found at {file_path}") llm_logger.info(f"Warning: Version file not found at {file_path}")
except Exception as e: except Exception as e:
llm_logger.info(f"Warning: Could not read version file - {str(e)}") llm_logger.info(f"Warning: Could not read version file - {e!s}")
def print(self): def print(self):
""" """
@@ -599,8 +596,7 @@ class CommitConfig:
llm_logger.info("Fasedeploy Commit Information :") llm_logger.info("Fasedeploy Commit Information :")
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info( llm_logger.info("=============================================================")
"=============================================================")
class Config: class Config:
@@ -728,7 +724,6 @@ class Config:
self.disable_any_whitespace = disable_any_whitespace self.disable_any_whitespace = disable_any_whitespace
self._str_to_list("innode_prefill_ports", int) self._str_to_list("innode_prefill_ports", int)
assert self.splitwise_role in ["mixed", "prefill", "decode"] assert self.splitwise_role in ["mixed", "prefill", "decode"]
# TODO # TODO
@@ -739,19 +734,16 @@ class Config:
self.max_prefill_batch = 1 # TODO: multimodal prefill currently only supports a parallelism degree of 1; to be optimized self.max_prefill_batch = 1 # TODO: multimodal prefill currently only supports a parallelism degree of 1; to be optimized
# TODO(@wufeisheng): TP and EP need to be supported simultaneously. # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
assert (self.tensor_parallel_size == 1 assert (self.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
and self.parallel_config.expert_parallel_size self.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
>= 1) or (self.tensor_parallel_size >= 1 ), "TP and EP cannot be enabled at the same time"
and self.parallel_config.expert_parallel_size
== 1), "TP and EP cannot be enabled at the same time"
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if num_ranks > self.max_chips_per_node: if num_ranks > self.max_chips_per_node:
self.worker_num_per_node = self.max_chips_per_node self.worker_num_per_node = self.max_chips_per_node
nnode = ceil_div(num_ranks, self.worker_num_per_node) nnode = ceil_div(num_ranks, self.worker_num_per_node)
assert nnode == self.nnode, \ assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
f"nnode: {nnode}, but got {self.nnode}"
else: else:
self.worker_num_per_node = num_ranks self.worker_num_per_node = num_ranks
@@ -772,13 +764,14 @@ class Config:
""" """
calculate some parameters calculate some parameters
""" """
assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \ assert (
f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}" self.device_ids.split(",").__len__() == self.worker_num_per_node
), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
assert self.worker_num_per_node % self.tensor_parallel_size == 0, \ assert (
f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}" self.worker_num_per_node % self.tensor_parallel_size == 0
self.local_device_ids = self.device_ids.split( ), f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
',')[:self.tensor_parallel_size] self.local_device_ids = self.device_ids.split(",")[: self.tensor_parallel_size]
self.host_ip = get_host_ip() self.host_ip = get_host_ip()
@@ -788,6 +781,7 @@ class Config:
self.is_master = False self.is_master = False
import paddle import paddle
self.paddle_commit_id = paddle.version.commit self.paddle_commit_id = paddle.version.commit
if self.max_num_batched_tokens is None: if self.max_num_batched_tokens is None:
@@ -799,10 +793,8 @@ class Config:
if self.long_prefill_token_threshold == 0: if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04) self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.cache_config.postprocess(self.max_num_batched_tokens, self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
self.max_num_seqs) self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
self.cache_config.max_block_num_per_seq = int(
self.max_model_len // self.cache_config.block_size)
if self.guided_decoding_backend == "auto": if self.guided_decoding_backend == "auto":
if self.enable_mm: if self.enable_mm:
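To put numbers on the two limits derived a few lines above (long_prefill_token_threshold and max_block_num_per_seq), assume max_model_len=32768 and block_size=64; these values are not taken from this diff:

max_model_len, block_size = 32768, 64

long_prefill_token_threshold = int(max_model_len * 0.04)   # 1310
max_block_num_per_seq = int(max_model_len // block_size)   # 512
print(long_prefill_token_threshold, max_block_num_per_seq)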
@@ -814,30 +806,26 @@ class Config:
""" """
check the legality of config check the legality of config
""" """
assert ( assert self.max_num_seqs <= 256, (
self.max_num_seqs <= 256 "The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
), "The parameter `max_num_seqs` is not allowed to exceed 256, " "but now it's {}.".format( )
self.max_num_seqs) assert is_port_available(
assert ( "0.0.0.0", self.engine_worker_queue_port
is_port_available('0.0.0.0', self.engine_worker_queue_port)
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use." ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
assert ( assert (
self.max_chips_per_node >= self.tensor_parallel_size > 0 self.max_chips_per_node >= self.tensor_parallel_size > 0
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}" ), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
assert (self.nnode >= 1), f"nnode: {self.nnode} should be no less than 1" assert self.nnode >= 1, f"nnode: {self.nnode} should be no less than 1"
assert ( assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
self.max_model_len >= 16 assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
), f"max_model_len: {self.max_model_len} should be larger than 16" assert self.max_num_batched_tokens >= self.max_num_seqs, (
assert ( f"max_num_batched_tokens: {self.max_num_batched_tokens} "
self.max_num_seqs
>= 1), f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
assert (
self.max_num_batched_tokens >= self.max_num_seqs
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}" f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
assert (self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs), \ )
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" \ assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger"
f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}" f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
)
assert ( assert (
self.max_num_partial_prefills >= 1 self.max_num_partial_prefills >= 1
), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1" ), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
@@ -845,31 +833,38 @@ class Config:
assert ( assert (
self.max_long_partial_prefills >= 1 self.max_long_partial_prefills >= 1
), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1" ), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
assert (self.max_long_partial_prefills <= self.max_num_partial_prefills), \ assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
f"max_long_partial_prefills: {self.max_long_partial_prefills} should " \ f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}" f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
)
if not self.cache_config.enable_chunked_prefill: if not self.cache_config.enable_chunked_prefill:
assert ( assert self.max_num_batched_tokens >= self.max_model_len, (
self.max_num_batched_tokens >= self.max_model_len f"max_num_batched_tokens: {self.max_num_batched_tokens} "
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
f"should be larger than or equal to max_model_len: {self.max_model_len}" f"should be larger than or equal to max_model_len: {self.max_model_len}"
)
else: else:
assert ( assert self.max_num_batched_tokens >= self.cache_config.block_size, (
self.max_num_batched_tokens >= self.cache_config.block_size f"max_num_batched_tokens: {self.max_num_batched_tokens} "
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
f"should be larger than or equal to block_size: {self.cache_config.block_size}" f"should be larger than or equal to block_size: {self.cache_config.block_size}"
)
if self.max_num_partial_prefills > 1: if self.max_num_partial_prefills > 1:
assert (self.cache_config.enable_chunked_prefill is True), \ assert (
"Chunked prefill must be enabled to set max_num_partial_prefills > 1" self.cache_config.enable_chunked_prefill is True
assert (self.long_prefill_token_threshold < self.max_model_len), \ ), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"\ assert self.long_prefill_token_threshold < self.max_model_len, (
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
f" max_model_len: {self.max_model_len}" f" max_model_len: {self.max_model_len}"
)
if self.guided_decoding_backend is not None: if self.guided_decoding_backend is not None:
assert self.guided_decoding_backend in ["xgrammar", "XGrammar", "auto", "off"], \ assert self.guided_decoding_backend in [
f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}." "xgrammar",
"XGrammar",
"auto",
"off",
], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
if self.guided_decoding_backend != "off": if self.guided_decoding_backend != "off":
# TODO: mm support guided_decoding # TODO: mm support guided_decoding
@@ -878,8 +873,7 @@ class Config:
# TODO: speculative decoding support guided_decoding # TODO: speculative decoding support guided_decoding
# TODO: xpu support guided_decoding # TODO: xpu support guided_decoding
assert not current_platform.is_xpu( assert not current_platform.is_xpu(), "XPU currently does not support guided_decoding"
), "XPU currently does not support guided_decoding"
try: try:
import xgrammar # noqa import xgrammar # noqa
@@ -897,22 +891,22 @@ class Config:
Args: Args:
file (str): the path of file to save config file (str): the path of file to save config
""" """
llm_logger.info( llm_logger.info("=================== Configuration Information ===============")
"=================== Configuration Information ===============")
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k == "generation_config" and v is not None: if k == "generation_config" and v is not None:
for gck, gcv in v.to_dict().items(): for gck, gcv in v.to_dict().items():
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv)) llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
elif (k == "cache_config" or elif (
k == "model_config" or k == "cache_config"
k == "scheduler_config" or or k == "model_config"
k == "parallel_config" or or k == "scheduler_config"
k == "commit_config"): or k == "parallel_config"
or k == "commit_config"
):
v.print() v.print()
else: else:
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v)) llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info( llm_logger.info("=============================================================")
"=============================================================")
if file is not None: if file is not None:
f = open(file, "a") f = open(file, "a")
now_time = datetime.now() now_time = datetime.now()
@@ -929,15 +923,14 @@ class Config:
if self.splitwise_role != "mixed": if self.splitwise_role != "mixed":
disaggregate_info["role"] = self.splitwise_role disaggregate_info["role"] = self.splitwise_role
disaggregate_info["cache_info"] = dict() disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split( current_protocol = self.cache_config.cache_transfer_protocol.split(",")
",")
disaggregate_info["transfer_protocol"] = current_protocol disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol: for protocol in current_protocol:
if protocol == "ipc": if protocol == "ipc":
disaggregate_info["cache_info"][protocol] = { disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip, "ip": self.host_ip,
"port": self.engine_worker_queue_port, "port": self.engine_worker_queue_port,
"device_ids": self.local_device_ids "device_ids": self.local_device_ids,
} }
elif protocol == "rdma": elif protocol == "rdma":
disaggregate_info["cache_info"][protocol] = { disaggregate_info["cache_info"][protocol] = {
@@ -957,13 +950,14 @@ class Config:
if hasattr(cls, key): if hasattr(cls, key):
value = getattr(cls, key) value = getattr(cls, key)
setattr(cls, value_name, value) setattr(cls, value_name, value)
llm_logger.info( llm_logger.info(f"Reset parameter {value_name} = {value} from configuration.")
f"Reset parameter {value_name} = {value} from configuration."
)
reset_value(self.cache_config, "block_size", "infer_model_block_size") reset_value(self.cache_config, "block_size", "infer_model_block_size")
reset_value(self.model_config, "return_full_hidden_states", reset_value(
"return_full_hidden_states") self.model_config,
"return_full_hidden_states",
"return_full_hidden_states",
)
reset_value(self.cache_config, "cache_dtype", "infer_model_dtype") reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
def _check_master(self): def _check_master(self):
View File
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
from __future__ import annotations from __future__ import annotations
import copy import copy
@@ -40,18 +41,21 @@ from fastdeploy.engine.expert_service import start_expert_service
from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (EngineCacheQueue, EngineWorkerQueue, from fastdeploy.inter_communicator import (
IPCSignal, ZmqClient) EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqClient,
)
from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import (TokenProcessor, from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
WarmUpTokenProcessor)
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger from fastdeploy.utils import EngineError, console_logger, llm_logger
class LLMEngine(object): class LLMEngine:
""" """
Engine class responsible for managing the Large Language Model (LLM) operations. Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -94,30 +98,28 @@ class LLMEngine(object):
self.running = True self.running = True
self.scheduler = cfg.scheduler_config.scheduler() self.scheduler = cfg.scheduler_config.scheduler()
self.input_processor = InputPreprocessor(cfg.tokenizer, self.input_processor = InputPreprocessor(
cfg.tokenizer,
cfg.reasoning_parser, cfg.reasoning_parser,
cfg.limit_mm_per_prompt, cfg.limit_mm_per_prompt,
cfg.mm_processor_kwargs, cfg.mm_processor_kwargs,
cfg.enable_mm) cfg.enable_mm,
)
self.start_queue_service() self.start_queue_service()
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role)
cfg.tensor_parallel_size,
cfg.splitwise_role)
os.environ['INFERENCE_MSG_QUEUE_ID'] = str( os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
self.cfg.engine_worker_queue_port)
self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
self.engine_worker_queue,
self.resource_manager)
self.token_processor = TokenProcessor( self.token_processor = TokenProcessor(
cfg=self.cfg, cfg=self.cfg,
cached_generated_tokens=self.scheduler, cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue, engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector) split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager) self.token_processor.set_resource_manager(self.resource_manager)
self.is_started = False self.is_started = False
@@ -129,11 +131,13 @@ class LLMEngine(object):
else: else:
self.do_profile = 0 self.do_profile = 0
self.partial_chunked_tokens = [0] * ( self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1): for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \ self.partial_chunked_tokens[idx] = (
// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size (self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self.partial_chunked_tokens[idx] = max(1, self.partial_chunked_tokens[idx]) self.partial_chunked_tokens[idx] = max(1, self.partial_chunked_tokens[idx])
self._finalizer = weakref.finalize(self, self._exit_sub_services) self._finalizer = weakref.finalize(self, self._exit_sub_services)
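The partial_chunked_tokens table initialised above fixes, for each possible number of concurrent partial prefills, a per-request chunk budget: the batched-token budget is divided by the request count and rounded down to a whole number of cache blocks. A small worked sketch with assumed values (the real numbers come from the engine config):

    # Assumed configuration values, for illustration only.
    max_num_batched_tokens = 8192
    block_size = 64
    max_num_partial_prefills = 3

    partial_chunked_tokens = [0] * (max_num_partial_prefills + 1)
    for idx in range(1, max_num_partial_prefills + 1):
        # Split the batched-token budget across idx requests, rounded down to whole blocks.
        partial_chunked_tokens[idx] = (max_num_batched_tokens // idx) // block_size * block_size
        partial_chunked_tokens[idx] = max(1, partial_chunked_tokens[idx])

    print(partial_chunked_tokens)  # [0, 8192, 4096, 2688]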
@@ -168,8 +172,8 @@ class LLMEngine(object):
time.sleep(3) time.sleep(3)
if self.do_profile == 0 and ( if self.do_profile == 0 and (
self.cfg.cache_config.enable_prefix_caching \ self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
or self.cfg.splitwise_role != "mixed"): ):
device_ids = self.cfg.device_ids.split(",") device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager( self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config, cache_config=self.cfg.cache_config,
@@ -177,16 +181,15 @@ class LLMEngine(object):
device_ids=device_ids, device_ids=device_ids,
pod_ip=self.cfg.master_ip, pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port, engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix) pid_suffix=self.ipc_signal_suffix,
)
self.worker_proc = self._start_worker_service() self.worker_proc = self._start_worker_service()
console_logger.info("Waitting worker processes ready...") console_logger.info("Waitting worker processes ready...")
time.sleep(5) time.sleep(5)
self.worker_init_status = dict() self.worker_init_status = dict()
if not self.check_worker_initialize_status(): if not self.check_worker_initialize_status():
console_logger.error( console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
"Failed to launch worker processes, check log/workerlog.* for more details."
)
return False return False
# Start warmup if enabled # Start warmup if enabled
@@ -199,17 +202,16 @@ class LLMEngine(object):
self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.tasks_queue = self.engine_worker_queue
self.insert_task_to_worker_thread = threading.Thread( self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
target=self._insert_task_to_worker, daemon=True)
self.insert_task_to_worker_thread.start() self.insert_task_to_worker_thread.start()
if self.api_server_pid is not None: if self.api_server_pid is not None:
self.insert_task_to_scheduler_thread = threading.Thread( self.insert_task_to_scheduler_thread = threading.Thread(
target=self._insert_zmq_task_to_scheduler, daemon=True) target=self._insert_zmq_task_to_scheduler, daemon=True
)
self.insert_task_to_scheduler_thread.start() self.insert_task_to_scheduler_thread.start()
self.receive_output_thread = threading.Thread( self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
target=self._zmq_send_generated_tokens, daemon=True)
self.receive_output_thread.start() self.receive_output_thread.start()
# Start TokenProcessor thread # Start TokenProcessor thread
@@ -223,8 +225,7 @@ class LLMEngine(object):
self.engine_worker_queue.available_prefill_instances.put(1) self.engine_worker_queue.available_prefill_instances.put(1)
self.split_mode_get_tasks() self.split_mode_get_tasks()
if self.cfg.scheduler_config.name == "splitwise": if self.cfg.scheduler_config.name == "splitwise":
self.splitwise_receive_thread = threading.Thread( self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
target=self.split_connector.start_receiver, args=())
self.splitwise_receive_thread.daemon = True self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start() self.splitwise_receive_thread.start()
@@ -240,20 +241,28 @@ class LLMEngine(object):
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.dp_processed = [] self.dp_processed = []
for i in range(1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode): for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
time.sleep(1) time.sleep(1)
self.dp_processed.append( self.dp_processed.append(
multiprocessing.Process(target=start_expert_service, multiprocessing.Process(
args=(self.cfg, target=start_expert_service,
args=(
self.cfg,
i + self.cfg.node_rank * self.cfg.worker_num_per_node, i + self.cfg.node_rank * self.cfg.worker_num_per_node,
self.ipc_signal_suffix))) self.ipc_signal_suffix,
llm_logger.info(f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" \ ),
+ " data parallel id {}".format(i)) )
)
llm_logger.info(
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start() self.dp_processed[-1].start()
console_logger.info( console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
"Worker processes are launched with {} seconds.".format(
time.time() - start_time))
return True return True
def _zmq_send_generated_tokens(self): def _zmq_send_generated_tokens(self):
@@ -271,8 +280,7 @@ class LLMEngine(object):
self.zmq_server.send_multipart(request_id, contents) self.zmq_server.send_multipart(request_id, contents)
except Exception as e: except Exception as e:
llm_logger.error("Unexcepted error happend: {}, {}".format( llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
e, str(traceback.format_exc())))
def _get_generated_result(self): def _get_generated_result(self):
""" """
@@ -296,8 +304,7 @@ class LLMEngine(object):
time.sleep(0.001) time.sleep(0.001)
continue continue
if self.exist_prefill_task_signal.value[0] > 0: if self.exist_prefill_task_signal.value[0] > 0:
if self.cfg.splitwise_role == "mixed" or \ if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
self.split_connector.has_splitwise_tasks():
time.sleep(0.005) time.sleep(0.005)
continue continue
if self.engine_worker_queue.num_cache_infos() > 0: if self.engine_worker_queue.num_cache_infos() > 0:
@@ -309,17 +316,17 @@ class LLMEngine(object):
num_prefill_batch = min( num_prefill_batch = min(
int(self.resource_manager.available_batch()), int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch) self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables() self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests( tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num( available_blocks=self.resource_manager.available_block_num(),
),
block_size=self.cfg.cache_config.block_size, block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config. reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens, max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch) batch=num_prefill_batch,
)
if len(tasks) == 0: if len(tasks) == 0:
time.sleep(0.001) time.sleep(0.001)
@@ -328,16 +335,14 @@ class LLMEngine(object):
current_id = (current_id + 1) % 100003 current_id = (current_id + 1) % 100003
if self.cfg.splitwise_role != "mixed": if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks") llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks( self.split_connector.send_splitwise_tasks(tasks, current_id)
tasks, current_id)
self.insert_tasks(tasks, current_id) self.insert_tasks(tasks, current_id)
main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e: except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format( err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
e, str(traceback.format_exc()))
llm_logger.error(err_msg) llm_logger.error(err_msg)
def _insert_zmq_task_to_scheduler(self): def _insert_zmq_task_to_scheduler(self):
@@ -353,8 +358,7 @@ class LLMEngine(object):
else: else:
err, data = self.zmq_server.receive_pyobj_once(block) err, data = self.zmq_server.receive_pyobj_once(block)
if err is not None: if err is not None:
llm_logger.error( llm_logger.error("Engine stops inserting zmq task into scheduler")
"Engine stops inserting zmq task into scheduler")
break break
request, insert_task = None, [] request, insert_task = None, []
@@ -363,13 +367,11 @@ class LLMEngine(object):
request = Request.from_dict(data) request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
llm_logger.debug(f"Receive request: {request}") llm_logger.debug(f"Receive request: {request}")
err_msg = None err_msg = None
if self.guided_decoding_checker is not None: if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format( request, err_msg = self.guided_decoding_checker.schema_format(request)
request)
if err_msg is not None: if err_msg is not None:
llm_logger.error(err_msg) llm_logger.error(err_msg)
@@ -394,17 +396,20 @@ class LLMEngine(object):
main_process_metrics.num_requests_waiting.inc(1) main_process_metrics.num_requests_waiting.inc(1)
continue continue
error_result = RequestOutput(request_id=request_id, error_result = RequestOutput(
request_id=request_id,
finished=True, finished=True,
error_code=500, error_code=500,
error_msg=failed) error_msg=failed,
)
# Since the request is not in scheduler # Since the request is not in scheduler
# Send result by zmq directly # Send result by zmq directly
self.zmq_server.send_multipart(request_id, error_result) self.zmq_server.send_multipart(request_id, error_result)
except Exception as e: except Exception as e:
llm_logger.error( llm_logger.error(
f"Error happend while receving new request from zmq, details={e}, " f"Error happend while receving new request from zmq, details={e}, "
f"traceback={traceback.format_exc()}") f"traceback={traceback.format_exc()}"
)
def add_requests(self, task, sampling_params=None, **kwargs): def add_requests(self, task, sampling_params=None, **kwargs):
""" """
@@ -428,23 +433,25 @@ class LLMEngine(object):
enable_thinking = None enable_thinking = None
if kwargs is not None: if kwargs is not None:
enable_thinking = kwargs.get("enable_thinking", None) enable_thinking = kwargs.get("enable_thinking", None)
request = self.data_processor.process_request( request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
request, self.cfg.max_model_len, enable_thinking=enable_thinking)
request.prompt_token_ids_len = len(request.prompt_token_ids) request.prompt_token_ids_len = len(request.prompt_token_ids)
input_ids_len = request.prompt_token_ids_len input_ids_len = request.prompt_token_ids_len
request.set( request.set(
"max_tokens", "max_tokens",
min(self.cfg.max_model_len - input_ids_len, min(
request.get("max_tokens"))) self.cfg.max_model_len - input_ids_len,
request.get("max_tokens"),
),
)
if request.get("reasoning_max_tokens") is None: if request.get("reasoning_max_tokens") is None:
default_reasoning_max_tokens = max( default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
int(request.get("max_tokens") * 0.8), 1)
request.set("reasoning_max_tokens", default_reasoning_max_tokens) request.set("reasoning_max_tokens", default_reasoning_max_tokens)
min_tokens = request.get("min_tokens") min_tokens = request.get("min_tokens")
if input_ids_len + min_tokens >= self.cfg.max_model_len: if input_ids_len + min_tokens >= self.cfg.max_model_len:
error_msg = ( error_msg = (
f"Input text is too long, length of prompt token({input_ids_len}) " f"Input text is too long, length of prompt token({input_ids_len}) "
f"+ min_dec_len ({min_tokens}) >= max_model_len ") f"+ min_dec_len ({min_tokens}) >= max_model_len "
)
llm_logger.error(error_msg) llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400) raise EngineError(error_msg, error_code=400)
@@ -456,16 +463,14 @@ class LLMEngine(object):
raise EngineError(error_msg, error_code=400) raise EngineError(error_msg, error_code=400)
if self.guided_decoding_checker is not None: if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format( request, err_msg = self.guided_decoding_checker.schema_format(request)
request)
if err_msg is not None: if err_msg is not None:
llm_logger.error(err_msg) llm_logger.error(err_msg)
raise EngineError(err_msg, error_code=400) raise EngineError(err_msg, error_code=400)
request.preprocess_end_time = time.time() request.preprocess_end_time = time.time()
self.scheduler.put_requests([request]) self.scheduler.put_requests([request])
llm_logger.info( llm_logger.info(f"Cache task with request_id ({request.get('request_id')})")
f"Cache task with request_id ({request.get('request_id')})")
llm_logger.debug(f"cache task: {request}") llm_logger.debug(f"cache task: {request}")
def warmup(self): def warmup(self):
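add_requests above clamps the request's max_tokens to whatever still fits after the prompt, derives a default reasoning_max_tokens from it, and rejects prompts whose length plus min_tokens already reaches max_model_len. A numeric sketch with assumed values:

    # Assumed values for illustration; the real ones come from the config and the request.
    max_model_len = 8192
    input_ids_len = 3000          # len(request.prompt_token_ids)
    requested_max_tokens = 6000
    min_tokens = 1

    max_tokens = min(max_model_len - input_ids_len, requested_max_tokens)  # 5192
    reasoning_max_tokens = max(int(max_tokens * 0.8), 1)                   # 4153, used when unset

    # Mirrors the length check above that raises EngineError(error_code=400).
    assert input_ids_len + min_tokens < max_model_len
    print(max_tokens, reasoning_max_tokens)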
@@ -486,25 +491,19 @@ class LLMEngine(object):
processed_indices = [] processed_indices = []
for idx, task in enumerate(self.waiting_requests): for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient( if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
task.prompt_token_ids_len):
self.insert_tasks([task]) self.insert_tasks([task])
llm_logger.info( llm_logger.info(f"Resource available, processing task {task.request_id}")
f"Resource available, processing task {task.request_id}"
)
processed_indices.append(idx) processed_indices.append(idx)
else: else:
llm_logger.debug( llm_logger.debug(f"Still waiting for resources {task.request_id}")
f"Still waiting for resources {task.request_id}"
)
break break
for idx in sorted(processed_indices, reverse=True): for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx) self.waiting_requests.pop(idx)
if not self.engine_worker_queue.disaggregate_queue_empty(): if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks( items = self.engine_worker_queue.get_disaggregated_tasks()
)
for item in items: for item in items:
role = item[0] role = item[0]
tasks = item[1] tasks = item[1]
@@ -515,7 +514,7 @@ class LLMEngine(object):
self.insert_tasks(tasks) self.insert_tasks(tasks)
elif role == "decode": elif role == "decode":
if hasattr(tasks[0], 'finished'): if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list): if not isinstance(tasks, list):
tasks = [tasks] tasks = [tasks]
for task in tasks: for task in tasks:
@@ -527,25 +526,19 @@ class LLMEngine(object):
else: else:
if len(self.waiting_requests): if len(self.waiting_requests):
llm_logger.info( llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
f"Waiting for resource for task {tasks[0].request_id}"
)
self.waiting_requests.extend(tasks) self.waiting_requests.extend(tasks)
else: else:
new_waiting = [] new_waiting = []
for task in tasks: for task in tasks:
if self.resource_manager.is_resource_sufficient( if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
task.prompt_token_ids_len):
self.insert_tasks([task]) self.insert_tasks([task])
else: else:
new_waiting.append(task) new_waiting.append(task)
if new_waiting: if new_waiting:
self.waiting_requests.extend( self.waiting_requests.extend(new_waiting)
new_waiting) llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
llm_logger.info(
f"Added {len(new_waiting)} tasks to waiting queue"
)
else: else:
time.sleep(0.001) time.sleep(0.001)
@@ -572,13 +565,10 @@ class LLMEngine(object):
if current_request_size[idx] <= 0: if current_request_size[idx] <= 0:
chunk_request_num -= 1 chunk_request_num -= 1
if not self.cfg.cache_config.enable_chunked_prefill or len( if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
requests) == 0:
return return
current_request_size = [ current_request_size = [request.prompt_token_ids_len for request in requests]
request.prompt_token_ids_len for request in requests
]
requests_chunk = [[] for _ in range(len(requests))] requests_chunk = [[] for _ in range(len(requests))]
chunk_request_num = len(current_request_size) chunk_request_num = len(current_request_size)
while chunk_request_num >= 1: while chunk_request_num >= 1:
@@ -588,25 +578,25 @@ class LLMEngine(object):
continue continue
chunk_size = min( chunk_size = min(
current_request_size[idx], current_request_size[idx],
self.partial_chunked_tokens[chunk_request_num]) self.partial_chunked_tokens[chunk_request_num],
)
update_tokens(idx, chunk_size) update_tokens(idx, chunk_size)
while remain_batched_tokens >= self.cfg.cache_config.block_size: while remain_batched_tokens >= self.cfg.cache_config.block_size:
# When max_num_batched_tokens still has leftover budget, give it to the shorter requests first # When max_num_batched_tokens still has leftover budget, give it to the shorter requests first
waiting_requests = [ waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0]
input_lens for input_lens in current_request_size
if input_lens > 0
]
if len(waiting_requests) == 0: if len(waiting_requests) == 0:
break break
available_tokens = remain_batched_tokens // self.cfg.cache_config.block_size * \ available_tokens = (
self.cfg.cache_config.block_size remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
)
append_idx = current_request_size.index(min(waiting_requests)) append_idx = current_request_size.index(min(waiting_requests))
chunk_size = min( chunk_size = min(
current_request_size[append_idx], current_request_size[append_idx],
self.partial_chunked_tokens[chunk_request_num], self.partial_chunked_tokens[chunk_request_num],
available_tokens) available_tokens,
)
update_tokens(append_idx, chunk_size, update_chunk=True) update_tokens(append_idx, chunk_size, update_chunk=True)
for idx in range(len(requests)): for idx in range(len(requests)):
@@ -616,8 +606,7 @@ class LLMEngine(object):
""" """
update each multimodal request's chunk size info update each multimodal request's chunk size info
""" """
if not self.cfg.cache_config.enable_chunked_prefill or len( if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
requests) == 0:
return return
for request in requests: for request in requests:
@@ -628,12 +617,9 @@ class LLMEngine(object):
inputs["grid_thw"] = np.array([], dtype="int64") inputs["grid_thw"] = np.array([], dtype="int64")
inputs["images"] = np.array([], dtype="uint8") inputs["images"] = np.array([], dtype="uint8")
input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
image_type_ids = paddle.to_tensor(inputs["image_type_ids"], image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32")
dtype="int32")
image_mask = input_ids == self.data_processor.image_patch_id image_mask = input_ids == self.data_processor.image_patch_id
image_token_sum = paddle.full(shape=[len(input_ids) + 1], image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32")
fill_value=0,
dtype="int32")
image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32")) image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"))
grid_thw = [] grid_thw = []
for one in inputs["grid_thw"]: for one in inputs["grid_thw"]:
@@ -644,45 +630,46 @@ class LLMEngine(object):
grid_thw = paddle.to_tensor(grid_thw, dtype="int64") grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
chunk_image_num, chunk_seq_len = get_mm_split_fuse( chunk_image_num, chunk_seq_len = get_mm_split_fuse(
input_ids, image_type_ids, image_token_sum, grid_thw, input_ids,
self.data_processor.image_patch_id, len(grid_thw), 0, image_type_ids,
len(input_ids), 0, self.partial_chunked_tokens[1], 2048) image_token_sum,
grid_thw,
self.data_processor.image_patch_id,
len(grid_thw),
0,
len(input_ids),
0,
self.partial_chunked_tokens[1],
2048,
)
grid_thw = grid_thw.numpy().reshape([-1, 3]) grid_thw = grid_thw.numpy().reshape([-1, 3])
num_chunks = len(chunk_image_num) num_chunks = len(chunk_image_num)
chunks_info = [] chunks_info = []
input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0 input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0
for idx in range(num_chunks): for idx in range(num_chunks):
chunk_input_ids = inputs["input_ids"][ chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
input_ids_st:input_ids_st + chunk_seq_len[idx]] chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
chunk_token_type_ids = inputs["token_type_ids"][ actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0])
input_ids_st:input_ids_st + chunk_seq_len[idx]]
actual_image_num = np.sum(grid_thw[grid_thw_st:grid_thw_st +
chunk_image_num[idx], 0])
chunk_image_type_ids = inputs["image_type_ids"][ chunk_image_type_ids = inputs["image_type_ids"][
image_type_ids_st:image_type_ids_st + actual_image_num] image_type_ids_st : image_type_ids_st + actual_image_num
chunk_grid_thw = grid_thw[grid_thw_st:grid_thw_st + ]
chunk_image_num[idx]] chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]]
chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1)) chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1))
chunk_images = inputs["images"][patch_st:patch_st + chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num]
chunk_patch_num]
chunks_info.append({ chunks_info.append(
"input_ids": {
chunk_input_ids, "input_ids": chunk_input_ids,
"token_type_ids": "token_type_ids": chunk_token_type_ids,
chunk_token_type_ids, "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None),
"image_type_ids": "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None),
chunk_image_type_ids "images": (chunk_images if chunk_images.shape[0] else None),
if chunk_image_type_ids.shape[0] else None, "position_ids": None,
"grid_thw": }
chunk_grid_thw if chunk_grid_thw.shape[0] else None, )
"images":
chunk_images if chunk_images.shape[0] else None,
"position_ids":
None
})
input_ids_st += chunk_seq_len[idx] input_ids_st += chunk_seq_len[idx]
image_type_ids_st += actual_image_num image_type_ids_st += actual_image_num
@@ -704,18 +691,14 @@ class LLMEngine(object):
del self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx] cur_task = self.resource_manager.tasks_list[cur_task_idx]
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
if self.cfg.speculative_config.method in [ if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
"mtp" cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(
task.outputs.draft_token_ids)
if task.error_code != 200: if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task) self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter: if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[ del self.token_processor.tokens_counter[task.request_id]
task.request_id]
self.scheduler.put_results([task]) self.scheduler.put_results([task])
llm_logger.warning( llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
@@ -723,8 +706,7 @@ class LLMEngine(object):
continue continue
self.token_processor.tokens_counter[task.request_id] = 1 self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task) current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks( self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
(current_tasks, self.resource_manager.real_bsz))
return True return True
self.resource_manager.check_and_free_block_tables() self.resource_manager.check_and_free_block_tables()
@@ -737,9 +719,7 @@ class LLMEngine(object):
available_batch = np.sum(self.resource_manager.stop_flags) available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch: if len(tasks) > available_batch:
llm_logger.error( llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
"Inserting batch:{} exceeds the available batch:{}.".format(
len(tasks), available_batch))
llm_logger.error("The exceeded part will be ignored!") llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch] tasks = tasks[:available_batch]
@@ -763,8 +743,7 @@ class LLMEngine(object):
is_decode = True is_decode = True
else: else:
is_prefill = True is_prefill = True
self.token_processor.number_of_input_tokens += tasks[ self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id) self.split_connector.send_cache_infos(tasks, current_id)
if not is_decode: if not is_decode:
@@ -776,8 +755,7 @@ class LLMEngine(object):
self.update_requests_chunk_size(tasks) self.update_requests_chunk_size(tasks)
else: else:
self.update_mm_requests_chunk_size(tasks) self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks( self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
(tasks, self.resource_manager.real_bsz))
if is_prefill and self.cfg.scheduler_config.name != "splitwise": if is_prefill and self.cfg.scheduler_config.name != "splitwise":
self.engine_worker_queue.available_prefill_instances.put(1) self.engine_worker_queue.available_prefill_instances.put(1)
return True return True
@@ -793,8 +771,7 @@ class LLMEngine(object):
""" """
judge if all tasks are finished judge if all tasks are finished
""" """
return np.sum(self.resource_manager.stop_flags) == len( return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
self.resource_manager.stop_flags)
def _set_warmup_token_processor(self): def _set_warmup_token_processor(self):
""" """
@@ -824,8 +801,7 @@ class LLMEngine(object):
judge if all worker processes are ready judge if all worker processes are ready
""" """
if np.sum(self.worker_ready_signal.value if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node:
) == self.cfg.worker_num_per_node:
return True return True
return False return False
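_worker_processes_ready above treats the workers as ready once the sum of the shared worker_ready_signal array equals worker_num_per_node, presumably because each worker writes a 1 into its own slot when it finishes initialising. A plain-numpy sketch of that check; the real array lives in shared memory behind IPCSignal:

    import numpy as np

    worker_num_per_node = 4  # assumed node size
    worker_ready_signal = np.zeros(shape=[worker_num_per_node], dtype=np.int32)

    # Two of the four workers have reported in so far.
    worker_ready_signal[0] = 1
    worker_ready_signal[2] = 1

    print(np.sum(worker_ready_signal) == worker_num_per_node)  # False until every slot is 1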
@@ -835,30 +811,33 @@ class LLMEngine(object):
""" """
# worker_ready_signal: tensor_parallel_size # worker_ready_signal: tensor_parallel_size
worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_ready_signal = IPCSignal(name="worker_ready_signal", self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
array=worker_ready_signal_data, array=worker_ready_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
# exist_task_signal lets each worker process detect whether there is a new task to handle # exist_task_signal lets each worker process detect whether there is a new task to handle
exist_task_signal_data = np.zeros( exist_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
[self.cfg.parallel_config.data_parallel_size], dtype=np.int32) self.exist_task_signal = IPCSignal(
self.exist_task_signal = IPCSignal(name="exist_task_signal", name="exist_task_signal",
array=exist_task_signal_data, array=exist_task_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
# exist_swapped_task_signal lets the engine detect whether a worker holds swapped tasks # exist_swapped_task_signal lets the engine detect whether a worker holds swapped tasks
exist_swapped_task_signal_data = np.zeros( exist_swapped_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
[self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal( self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal", name="exist_swapped_task_signal",
array=exist_swapped_task_signal_data, array=exist_swapped_task_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
# exist_prefill_task_signal lets each worker process detect whether prefill should run # exist_prefill_task_signal lets each worker process detect whether prefill should run
exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32) exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
@@ -867,17 +846,18 @@ class LLMEngine(object):
array=exist_prefill_task_signal_data, array=exist_prefill_task_signal_data,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
# worker_live_signal lets the engine detect whether each worker process is alive; records the time of each step # worker_live_signal lets the engine detect whether each worker process is alive; records the time of each step
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal( self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal", name="worker_healthy_live_signal",
array=worker_healthy_live_recorded_time_array, array=worker_healthy_live_recorded_time_array,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
if self.do_profile: if self.do_profile:
get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32) get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
@@ -886,7 +866,8 @@ class LLMEngine(object):
array=get_profile_block_num, array=get_profile_block_num,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32) model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal( self.model_weights_status_signal = IPCSignal(
@@ -894,7 +875,8 @@ class LLMEngine(object):
array=model_weights_status, array=model_weights_status,
dtype=np.int32, dtype=np.int32,
suffix=self.ipc_signal_suffix, suffix=self.ipc_signal_suffix,
create=True) create=True,
)
def _exit_sub_services(self): def _exit_sub_services(self):
""" """
@@ -903,8 +885,7 @@ class LLMEngine(object):
self.running = False self.running = False
if hasattr(self, "cache_manager_processes"): if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear( self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
)
self.resource_manager.cache_manager.cache_ready_signal.clear() self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes: for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}") llm_logger.info(f"Killing cache manager process {p.pid}")
@@ -943,7 +924,7 @@ class LLMEngine(object):
"TRAINER_INSTANCES_NUM": 1, "TRAINER_INSTANCES_NUM": 1,
"TRAINER_INSTANCES": "0.0.0.0", "TRAINER_INSTANCES": "0.0.0.0",
"ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0, "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
"LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(',')), "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(",")),
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"FLAGS_use_append_attn": 1, "FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring", "NCCL_ALGO": "Ring",
@@ -951,24 +932,22 @@ class LLMEngine(object):
"FLAGS_hardamard_moe_block_size": 128, "FLAGS_hardamard_moe_block_size": 128,
} }
# environment variables needed by Dy2St # environment variables needed by Dy2St
variables.update({ variables.update(
"SOT_LOG_LEVEL": {
os.getenv("SOT_LOG_LEVEL", default="0"), "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
"SOT_UNSAFE_CACHE_FASTPATH": "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"), "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"SOT_ENABLE_0_SIZE_FALLBACK": "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"), "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_specialize_device_in_dy2st": "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
os.getenv("FLAGS_specialize_device_in_dy2st", default="1"), "FLAGS_pir_interpreter_record_stream_for_gc_cache",
"FLAGS_enable_async_fast_gc": default="1",
os.getenv("FLAGS_enable_async_fast_gc", default="0"), ),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
os.getenv("FLAGS_pir_interpreter_record_stream_for_gc_cache", "FLAGS_parameters_persistent_mode_in_dy2st", default="1"
default="1"), ),
"FLAGS_parameters_persistent_mode_in_dy2st": }
os.getenv("FLAGS_parameters_persistent_mode_in_dy2st", )
default="1"),
})
if self.cfg.splitwise_role != "mixed": if self.cfg.splitwise_role != "mixed":
variables["FLAGS_use_pd_disaggregation"] = 1 variables["FLAGS_use_pd_disaggregation"] = 1
@@ -994,8 +973,7 @@ class LLMEngine(object):
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
current_dir_path = os.path.split(current_file_path)[0] current_dir_path = os.path.split(current_file_path)[0]
# TODO # TODO
uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == 1 else "-u"
"0") == 1 else "-u"
pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch" pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch"
pd_cmd = pd_cmd + f" --log_dir {log_dir}" pd_cmd = pd_cmd + f" --log_dir {log_dir}"
@@ -1004,7 +982,7 @@ class LLMEngine(object):
ori_vocab_size = ( ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model) len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model') if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab) else len(self.data_processor.tokenizer.vocab)
) )
@@ -1012,10 +990,10 @@ class LLMEngine(object):
f" --devices {self.cfg.device_ids} {py_script}" f" --devices {self.cfg.device_ids} {py_script}"
f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}" f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}" f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
f" --model_name_or_path {str(self.cfg.model_name_or_path)}" f" --model_name_or_path {self.cfg.model_name_or_path!s}"
f" --device_ids {self.cfg.device_ids}" f" --device_ids {self.cfg.device_ids}"
f" --tensor_parallel_size {self.cfg.tensor_parallel_size}" f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}" f" --engine_worker_queue_port {self.cfg.engine_worker_queue_port!s}"
f" --pod_ip {self.cfg.master_ip}" f" --pod_ip {self.cfg.master_ip}"
f" --total_block_num {self.cfg.cache_config.total_block_num}" f" --total_block_num {self.cfg.cache_config.total_block_num}"
f" --block_size {self.cfg.cache_config.block_size}" f" --block_size {self.cfg.cache_config.block_size}"
@@ -1036,16 +1014,13 @@ class LLMEngine(object):
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}" f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'" f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}") f" --load_strategy {self.cfg.model_config.load_strategy}"
)
worker_append_flag = { worker_append_flag = {
"enable_expert_parallel": "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
self.cfg.parallel_config.enable_expert_parallel, "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
"enable_prefix_caching": "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill":
self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile, "do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.model_config.dynamic_load_weight, "dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.disable_any_whitespace, "disable_any_whitespace": self.cfg.disable_any_whitespace,
@@ -1059,11 +1034,11 @@ class LLMEngine(object):
if self.cfg.nnode > 1: if self.cfg.nnode > 1:
pd_cmd = pd_cmd + ( pd_cmd = pd_cmd + (
f" --master {self.cfg.dist_init_addr}" f" --master {self.cfg.dist_init_addr}"
f" --nnodes {str(self.cfg.nnode)}" f" --nnodes {self.cfg.nnode!s}"
f" --rank {str(self.cfg.node_rank)}" f" --rank {self.cfg.node_rank!s}"
) )
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log" pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
llm_logger.info("Launch worker service command: {}".format(pd_cmd)) llm_logger.info(f"Launch worker service command: {pd_cmd}")
p = subprocess.Popen( p = subprocess.Popen(
pd_cmd, pd_cmd,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
@@ -1111,8 +1086,7 @@ class LLMEngine(object):
try: try:
req_id = self._format_and_add_data(prompts) req_id = self._format_and_add_data(prompts)
except Exception as e: except Exception as e:
llm_logger.error( llm_logger.error(f"Error happend while adding request, details={e}")
f"Error happend while adding request, details={e}")
raise EngineError(str(e), error_code=400) raise EngineError(str(e), error_code=400)
# Fetch the result of the current request # Fetch the result of the current request
@@ -1151,8 +1125,7 @@ class LLMEngine(object):
if num_gpu_blocks < 0: if num_gpu_blocks < 0:
num_gpu_blocks = self.get_profile_block_num_signal.value[i] num_gpu_blocks = self.get_profile_block_num_signal.value[i]
else: else:
num_gpu_blocks = min( num_gpu_blocks = min(num_gpu_blocks, self.get_profile_block_num_signal.value[i])
num_gpu_blocks, self.get_profile_block_num_signal.value[i])
self.cfg.cache_config.reset(num_gpu_blocks) self.cfg.cache_config.reset(num_gpu_blocks)
self.resource_manager.reset_cache_config(self.cfg.cache_config) self.resource_manager.reset_cache_config(self.cfg.cache_config)
@@ -1164,15 +1137,16 @@ class LLMEngine(object):
device_ids=device_ids, device_ids=device_ids,
pod_ip=self.cfg.master_ip, pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port, engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix) pid_suffix=self.ipc_signal_suffix,
)
def check_health(self, time_interval_threashold=30): def check_health(self, time_interval_threashold=30):
""" """
Check the health of the model server by checking whether all workers are alive. Check the health of the model server by checking whether all workers are alive.
""" """
if self.worker_healthy_live_signal.value[0]: if self.worker_healthy_live_signal.value[0]:
elapsed_time = time.time() - \ elapsed_time = time.time() - self.worker_healthy_live_signal.value[0]
self.worker_healthy_live_signal.value[0]
if elapsed_time > time_interval_threashold: if elapsed_time > time_interval_threashold:
return False, "Worker Service Not Healthy" return False, "Worker Service Not Healthy"
@@ -1185,38 +1159,31 @@ class LLMEngine(object):
def detect_thread(): def detect_thread():
for line in self.worker_proc.stdout: for line in self.worker_proc.stdout:
line = line.decode('utf-8', errors='ignore') line = line.decode("utf-8", errors="ignore")
if self.worker_init_status.get("finished", False): if self.worker_init_status.get("finished", False):
break break
if match := re.search( if match := re.search(
r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)', r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)",
line): line,
self.worker_init_status["weight_loadding"] = eval( ):
match.group(1)) * 1.0 / 100 self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100
elif (match := re.search(r'Start load layer (\d+)', elif (match := re.search(r"Start load layer (\d+)", line)) or (
line)) or (match := re.search( match := re.search(r"set state for layer (\d+)", line)
r'set state for layer (\d+)', ):
line)): progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_layers
progress = eval(match.group(
1)) * 1.0 / self.cfg.model_config.num_layers
self.worker_init_status["layer_loadding"] = progress self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status[ if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_layers - 1:
"layer_loadding"] == self.cfg.model_config.num_layers - 1:
self.worker_init_status["finished"] = True self.worker_init_status["finished"] = True
self.checking_worker_status_thread = threading.Thread( self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
target=detect_thread, daemon=True)
self.checking_worker_status_thread.start() self.checking_worker_status_thread.start()
# display weight loadding progress # display weight loadding progress
with tqdm(total=100, desc="Loading Weights") as pbar: with tqdm(total=100, desc="Loading Weights") as pbar:
progress = 0 progress = 0
while progress < 100: while progress < 100:
progress = int( progress = int(self.worker_init_status.get("weight_loadding", 0) * 100)
self.worker_init_status.get("weight_loadding", 0) * 100) if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready():
if self.worker_init_status.get(
"layer_loadding",
0) > 0 or self._worker_processes_ready():
progress = 100 progress = 100
pbar.update(progress - pbar.n) pbar.update(progress - pbar.n)
pbar.refresh() pbar.refresh()
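detect_thread above scrapes the worker's stdout with regular expressions to estimate loading progress. A sketch of the weight-loading pattern against a hypothetical log line (the line format is assumed; int() stands in for the eval() used in the diff, which is enough for a captured integer):

    import re

    line = "Loading safetensors checkpoint shards:  45% 9/20"  # hypothetical worker log line

    if match := re.search(r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)", line):
        weight_loading = int(match.group(1)) / 100
        print(weight_loading)  # 0.45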
@@ -1228,8 +1195,7 @@ class LLMEngine(object):
with tqdm(total=100, desc="Loading Layers") as pbar: with tqdm(total=100, desc="Loading Layers") as pbar:
progress = 0 progress = 0
while progress < 100: while progress < 100:
progress = int( progress = int(self.worker_init_status.get("layer_loadding", 0) * 100)
self.worker_init_status.get("layer_loadding", 0) * 100)
if self._worker_processes_ready(): if self._worker_processes_ready():
progress = 100 progress = 100
pbar.update(progress - pbar.n) pbar.update(progress - pbar.n)
@@ -1256,19 +1222,21 @@ class LLMEngine(object):
address=address, address=address,
is_server=True, is_server=True,
num_client=self.cfg.tensor_parallel_size, num_client=self.cfg.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config. local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
data_parallel_size) )
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed': if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
self.cache_task_queue = EngineCacheQueue( self.cache_task_queue = EngineCacheQueue(
address=(self.cfg.master_ip, self.cfg.cache_config.cache_queue_port), address=(
authkey=b'cache_queue_service', self.cfg.master_ip,
self.cfg.cache_config.cache_queue_port,
),
authkey=b"cache_queue_service",
is_server=True, is_server=True,
num_client=self.cfg.tensor_parallel_size, num_client=self.cfg.tensor_parallel_size,
client_id=-1, client_id=-1,
local_data_parallel_size=self.cfg.parallel_config. local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
data_parallel_size) )
self.engine_worker_queue = EngineWorkerQueue( self.engine_worker_queue = EngineWorkerQueue(
address=address, address=address,
@@ -1276,5 +1244,8 @@ class LLMEngine(object):
num_client=self.cfg.tensor_parallel_size, num_client=self.cfg.tensor_parallel_size,
client_id=0, client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id= min(self.cfg.worker_num_per_node * self.cfg.node_rank, local_data_parallel_id=min(
self.cfg.parallel_config.data_parallel_size - 1)) self.cfg.worker_num_per_node * self.cfg.node_rank,
self.cfg.parallel_config.data_parallel_size - 1,
),
)


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
from __future__ import annotations from __future__ import annotations
import os import os
@@ -32,7 +33,7 @@ from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger from fastdeploy.utils import EngineError, console_logger, llm_logger
class ExpertService(object): class ExpertService:
""" """
Engine class responsible for managing the Large Language Model (LLM) operations. Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -51,17 +52,14 @@ class ExpertService(object):
self.cfg = cfg self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[ self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
start_pos:end_pos] self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(
",")[start_pos:end_pos]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.cfg.disaggregate_info = None self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler() self.scheduler = cfg.scheduler_config.scheduler()
self.scheduler.reset_nodeid( self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
f"{self.scheduler.infer.nodeid}_{str(local_data_parallel_id)}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
@@ -73,33 +71,41 @@ class ExpertService(object):
num_client=cfg.tensor_parallel_size, num_client=cfg.tensor_parallel_size,
local_data_parallel_id=local_data_parallel_id, local_data_parallel_id=local_data_parallel_id,
) )
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, \ self.resource_manager = ResourceManager(
cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id) cfg.max_num_seqs,
cfg,
cfg.tensor_parallel_size,
cfg.splitwise_role,
local_data_parallel_id,
)
if len(self.cfg.cache_config.pd_comm_port) == 1: if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = int( self.cfg.cache_config.pd_comm_port[0] = int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
else: else:
self.cfg.cache_config.pd_comm_port = [ self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.cfg.cache_config.pd_comm_port[local_data_parallel_id]
]
self.split_connector = SplitwiseConnector(self.cfg, self.scheduler, self.split_connector = SplitwiseConnector(
self.cfg,
self.scheduler,
self.engine_worker_queue, self.engine_worker_queue,
self.resource_manager) self.resource_manager,
)
self.token_processor = TokenProcessor( self.token_processor = TokenProcessor(
cfg=cfg, cfg=cfg,
cached_generated_tokens=self.scheduler, cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue, engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector) split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager) self.token_processor.set_resource_manager(self.resource_manager)
self.partial_chunked_tokens = [0] * ( self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1): for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \ self.partial_chunked_tokens[idx] = (
// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size (self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self._finalizer = weakref.finalize(self, self._exit_sub_services) self._finalizer = weakref.finalize(self, self._exit_sub_services)
@@ -120,17 +126,15 @@ class ExpertService(object):
device_ids=self.cfg.local_device_ids, device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.master_ip, pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port, engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}" pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
) )
self.insert_task_to_worker_thread = threading.Thread( self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread.daemon = True self.insert_task_to_worker_thread.daemon = True
self.insert_task_to_worker_thread.start() self.insert_task_to_worker_thread.start()
# Start TokenProcessor thread # Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str( os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.token_processor.run() self.token_processor.run()
@@ -144,9 +148,7 @@ class ExpertService(object):
self.scheduler.start(role, host_ip, disaggregate) self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print() self.cfg.print()
console_logger.info( console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
"Worker processes are launched with {} seconds.".format(
time.time() - start_time))
return True return True
def _insert_task_to_worker(self): def _insert_task_to_worker(self):
@@ -169,17 +171,17 @@ class ExpertService(object):
num_prefill_batch = min( num_prefill_batch = min(
int(self.resource_manager.available_batch()), int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch) self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables() self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests( tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num( available_blocks=self.resource_manager.available_block_num(),
),
block_size=self.cfg.cache_config.block_size, block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config. reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens, max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch) batch=num_prefill_batch,
)
if len(tasks) == 0: if len(tasks) == 0:
time.sleep(0.001) time.sleep(0.001)
@@ -187,8 +189,7 @@ class ExpertService(object):
if self.cfg.splitwise_role != "mixed": if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks") llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks( self.split_connector.send_splitwise_tasks(tasks, current_id)
tasks, current_id)
current_id = (current_id + 1) % 100003 current_id = (current_id + 1) % 100003
@@ -197,8 +198,7 @@ class ExpertService(object):
main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e: except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format( err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
e, str(traceback.format_exc()))
llm_logger.error(err_msg) llm_logger.error(err_msg)
def split_mode_get_tasks(self): def split_mode_get_tasks(self):
@@ -212,15 +212,13 @@ class ExpertService(object):
try: try:
if len(waiting_requests) > 0: if len(waiting_requests) > 0:
for task in waiting_requests: for task in waiting_requests:
if self.resource_manager.is_resource_sufficient( if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
task.prompt_token_ids_len):
self.insert_tasks([task]) self.insert_tasks([task])
waiting_requests.remove(task) waiting_requests.remove(task)
else: else:
break break
if not self.engine_worker_queue.disaggregate_queue_empty(): if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks( items = self.engine_worker_queue.get_disaggregated_tasks()
)
for item in items: for item in items:
role = item[0] role = item[0]
tasks = item[1] tasks = item[1]
@@ -231,7 +229,7 @@ class ExpertService(object):
self.insert_tasks(tasks) self.insert_tasks(tasks)
elif role == "decode": elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}") llm_logger.info(f"get decode tasks {tasks}")
if hasattr(tasks[0], 'finished'): if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list): if not isinstance(tasks, list):
tasks = [tasks] tasks = [tasks]
for task in tasks: for task in tasks:
@@ -246,7 +244,8 @@ class ExpertService(object):
else: else:
for task in tasks: for task in tasks:
if not self.resource_manager.is_resource_sufficient( if not self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len): task.prompt_token_ids_len
):
waiting_requests.append(task) waiting_requests.append(task)
else: else:
self.insert_tasks([task]) self.insert_tasks([task])
@@ -274,8 +273,7 @@ class ExpertService(object):
self.resource_manager.tasks_list[cur_task_idx] = None self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task) self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter: if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[ del self.token_processor.tokens_counter[task.request_id]
task.request_id]
self.scheduler.put_results([task]) self.scheduler.put_results([task])
llm_logger.warning( llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource." f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
@@ -285,8 +283,7 @@ class ExpertService(object):
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1 self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task) current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks( self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
(current_tasks, self.resource_manager.real_bsz))
return True return True
self.resource_manager.check_and_free_block_tables() self.resource_manager.check_and_free_block_tables()
@@ -299,9 +296,7 @@ class ExpertService(object):
available_batch = np.sum(self.resource_manager.stop_flags) available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch: if len(tasks) > available_batch:
llm_logger.error( llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
"Inserting batch:{} exceeds the available batch:{}.".format(
len(tasks), available_batch))
llm_logger.error("The exceeded part will be ignored!") llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch] tasks = tasks[:available_batch]
@@ -325,8 +320,7 @@ class ExpertService(object):
is_decode = True is_decode = True
else: else:
is_prefill = True is_prefill = True
self.token_processor.number_of_input_tokens += tasks[ self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id) self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks: for task in tasks:
@@ -338,8 +332,7 @@ class ExpertService(object):
self.update_requests_chunk_size(tasks) self.update_requests_chunk_size(tasks)
else: else:
self.update_mm_requests_chunk_size(tasks) self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks( self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
(tasks, self.resource_manager.real_bsz))
return True return True
def _exit_sub_services(self): def _exit_sub_services(self):
@@ -348,8 +341,7 @@ class ExpertService(object):
""" """
if hasattr(self, "cache_manager_processes"): if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear( self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
)
self.resource_manager.cache_manager.cache_ready_signal.clear() self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes: for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}") llm_logger.info(f"Killing cache manager process {p.pid}")

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import list from typing import list
@@ -25,6 +26,7 @@ class KVCacheSpec:
""" """
A base class for specifying the KV cache format of one layer. A base class for specifying the KV cache format of one layer.
""" """
# number of tokens in a block # number of tokens in a block
block_size: int block_size: int
# the memory size used by each block in bytes. # the memory size used by each block in bytes.
@@ -37,10 +39,9 @@ class KVCacheSpec:
""" """
# check list # check list
assert all( assert all(
(spec.block_size == specs[0].block_size (spec.block_size == specs[0].block_size and spec.block_memory_used == specs[0].block_memory_used)
and spec.block_memory_used == specs[0].block_memory_used) for spec in specs[1:]
for spec in specs[1:]), ( ), "All layers in the model must share the same block_size."
"All layers in the model must share the same block_size.")
return copy.deepcopy(specs[0]) return copy.deepcopy(specs[0])
@@ -48,6 +49,7 @@ class KVCacheSpec:
@dataclass @dataclass
class AttentionSpec(KVCacheSpec): class AttentionSpec(KVCacheSpec):
""" """ """ """
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: str dtype: str
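Note: the reshaped assertion in KVCacheSpec keeps the same invariant: every layer must report the first layer's block_size and block_memory_used. A standalone sketch of that check (the Spec stand-in and the merge_specs name are illustrative; the enclosing method name is not visible in the hunk):

    from dataclasses import dataclass

    @dataclass
    class Spec:  # stand-in for KVCacheSpec, illustration only
        block_size: int
        block_memory_used: int

    def merge_specs(specs: list[Spec]) -> Spec:
        # All layers must share the same block geometry before they can be merged.
        assert all(
            spec.block_size == specs[0].block_size and spec.block_memory_used == specs[0].block_memory_used
            for spec in specs[1:]
        ), "All layers in the model must share the same block_size."
        return specs[0]

    merged = merge_specs([Spec(64, 4096), Spec(64, 4096)])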

View File

@@ -29,8 +29,8 @@ from fastdeploy.worker.output import LogprobsLists
@dataclass @dataclass
class Request: class Request:
def __init__(
def __init__(self, self,
request_id: str, request_id: str,
prompt: Optional[Union[str, list[str]]], prompt: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[list[int]], prompt_token_ids: Optional[list[int]],
@@ -56,7 +56,8 @@ class Request:
structural_tag: Optional[Any] = None, structural_tag: Optional[Any] = None,
guided_json_object: Optional[bool] = None, guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True, enable_thinking: Optional[bool] = True,
trace_carrier: dict = dict()) -> None: trace_carrier: dict = dict(),
) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
@@ -98,7 +99,8 @@ class Request:
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
data_processor_logger.debug(f"{d}") data_processor_logger.debug(f"{d}")
sampling_params = SamplingParams.from_dict(d) sampling_params = SamplingParams.from_dict(d)
return cls(request_id=d["request_id"], return cls(
request_id=d["request_id"],
prompt=d.get("prompt"), prompt=d.get("prompt"),
prompt_token_ids=d.get("prompt_token_ids"), prompt_token_ids=d.get("prompt_token_ids"),
prompt_token_ids_len=d.get("prompt_token_ids_len"), prompt_token_ids_len=d.get("prompt_token_ids_len"),
@@ -123,7 +125,8 @@ class Request:
structural_tag=d.get("structural_tag", None), structural_tag=d.get("structural_tag", None),
guided_json_object=d.get("guided_json_object", None), guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", True), enable_thinking=d.get("enable_thinking", True),
trace_carrier=d.get("trace_carrier", {})) trace_carrier=d.get("trace_carrier", {}),
)
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""convert Request into a serializable dict""" """convert Request into a serializable dict"""
@@ -146,11 +149,15 @@ class Request:
"disaggregate_info": self.disaggregate_info, "disaggregate_info": self.disaggregate_info,
"draft_token_ids": self.draft_token_ids, "draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking, "enable_thinking": self.enable_thinking,
"trace_carrier": self.trace_carrier "trace_carrier": self.trace_carrier,
} }
add_params = [ add_params = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar", "guided_json",
"structural_tag", "guided_json_object" "guided_regex",
"guided_choice",
"guided_grammar",
"structural_tag",
"guided_json_object",
] ]
for param in add_params: for param in add_params:
if getattr(self, param, None) is not None: if getattr(self, param, None) is not None:
@@ -174,11 +181,13 @@ class Request:
setattr(self, key, value) setattr(self, key, value)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"Request(request_id={self.request_id}, " return (
f"Request(request_id={self.request_id}, "
f"prompt={self.prompt!r}, " f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"draft_token_ids={self.draft_token_ids}, " f"draft_token_ids={self.draft_token_ids}, "
f"sampling_params={self.sampling_params})") f"sampling_params={self.sampling_params})"
)
@dataclass(slots=True) @dataclass(slots=True)
@@ -212,27 +221,28 @@ class CompletionOutput:
"top_logprobs": self.top_logprobs, "top_logprobs": self.top_logprobs,
"draft_token_ids": self.draft_token_ids, "draft_token_ids": self.draft_token_ids,
"text": self.text, "text": self.text,
"reasoning_content": self.reasoning_content "reasoning_content": self.reasoning_content,
} }
@classmethod @classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> 'CompletionOutput': def from_dict(cls, req_dict: dict[str, Any]) -> CompletionOutput:
"""Create instance from dict arguments""" """Create instance from dict arguments"""
return cls( return cls(
**{ **{
field.name: field.name: (req_dict[field.name] if field.name in req_dict else field.default)
req_dict[field.name] if field.name in
req_dict else field.default
for field in fields(cls) for field in fields(cls)
}) }
)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"CompletionOutput(index={self.index}, " return (
f"CompletionOutput(index={self.index}, "
f"send_idx={self.send_idx}, " f"send_idx={self.send_idx}, "
f"text={self.text!r}, " f"text={self.text!r}, "
f"token_ids={self.token_ids}, " f"token_ids={self.token_ids}, "
f"draft_token_ids={self.draft_token_ids}, " f"draft_token_ids={self.draft_token_ids}, "
f"reasoning_content={self.reasoning_content!r}") f"reasoning_content={self.reasoning_content!r}"
)
@dataclass(slots=True) @dataclass(slots=True)
@@ -252,6 +262,7 @@ class RequestMetrics:
request_start_time: Time to accept the request request_start_time: Time to accept the request
""" """
arrival_time: float arrival_time: float
inference_start_time: Optional[float] = None inference_start_time: Optional[float] = None
first_token_time: Optional[float] = None first_token_time: Optional[float] = None
@@ -273,19 +284,18 @@ class RequestMetrics:
"preprocess_cost_time": self.preprocess_cost_time, "preprocess_cost_time": self.preprocess_cost_time,
"model_forward_time": self.model_forward_time, "model_forward_time": self.model_forward_time,
"model_execute_time": self.model_execute_time, "model_execute_time": self.model_execute_time,
"request_start_time": self.request_start_time "request_start_time": self.request_start_time,
} }
@classmethod @classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> 'RequestMetrics': def from_dict(cls, req_dict: dict[str, Any]) -> RequestMetrics:
"""Create instance from dict arguments""" """Create instance from dict arguments"""
return cls( return cls(
**{ **{
field.name: field.name: (req_dict[field.name] if field.name in req_dict else field.default)
req_dict[field.name] if field.name in
req_dict else field.default
for field in fields(cls) for field in fields(cls)
}) }
)
class RequestOutput: class RequestOutput:
@@ -333,13 +343,12 @@ class RequestOutput:
self.error_code = error_code self.error_code = error_code
self.error_msg = error_msg self.error_msg = error_msg
if prompt_token_ids is None: if prompt_token_ids is None:
self.prompt_token_ids = [] self.prompt_token_ids = []
elif isinstance(self.prompt_token_ids, np.ndarray): elif isinstance(self.prompt_token_ids, np.ndarray):
self.prompt_token_ids = self.prompt_token_ids.tolist() self.prompt_token_ids = self.prompt_token_ids.tolist()
def add(self, next_output: "RequestOutput") -> None: def add(self, next_output: RequestOutput) -> None:
"""Merge RequestOutput into this one""" """Merge RequestOutput into this one"""
self.prompt = next_output.prompt self.prompt = next_output.prompt
@@ -348,19 +357,19 @@ class RequestOutput:
self.outputs.index = next_output.outputs.index self.outputs.index = next_output.outputs.index
self.outputs.token_ids.extend(next_output.outputs.token_ids) self.outputs.token_ids.extend(next_output.outputs.token_ids)
if next_output.metrics.arrival_time is not None and self.metrics.inference_start_time is not None: if next_output.metrics.arrival_time is not None and self.metrics.inference_start_time is not None:
self.metrics.model_forward_time = next_output.metrics.arrival_time - \ self.metrics.model_forward_time = next_output.metrics.arrival_time - self.metrics.inference_start_time
self.metrics.inference_start_time
if next_output.metrics.arrival_time is not None and self.metrics.arrival_time is not None: if next_output.metrics.arrival_time is not None and self.metrics.arrival_time is not None:
self.metrics.model_execute_time = next_output.metrics.arrival_time - \ self.metrics.model_execute_time = next_output.metrics.arrival_time - self.metrics.arrival_time
self.metrics.arrival_time
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, " f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs}, " f"outputs={self.outputs}, "
f"metrics={self.metrics}, " f"metrics={self.metrics}, "
f"num_cached_tokens={self.num_cached_tokens})") f"num_cached_tokens={self.num_cached_tokens})"
)
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
@@ -376,10 +385,8 @@ class RequestOutput:
"request_id": self.request_id, "request_id": self.request_id,
"prompt": self.prompt, "prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids, "prompt_token_ids": self.prompt_token_ids,
"outputs": "outputs": None if self.outputs is None else self.outputs.to_dict(),
None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(),
"metrics":
None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished, "finished": self.finished,
"num_cached_tokens": self.num_cached_tokens, "num_cached_tokens": self.num_cached_tokens,
"error_code": self.error_code, "error_code": self.error_code,

View File

@@ -25,17 +25,19 @@ from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger from fastdeploy.utils import llm_logger
class ResourceManager(object): class ResourceManager:
""" """
record and allocate resources for the engine record and allocate resources for the engine
""" """
def __init__(self, def __init__(
self,
max_num_seqs, max_num_seqs,
config, config,
tensor_parallel_size, tensor_parallel_size,
splitwise_role, splitwise_role,
local_data_parallel_id=0): local_data_parallel_id=0,
):
""" """
Args: Args:
cfg (Config): config object containing parameters for the engine cfg (Config): config object containing parameters for the engine
@@ -51,9 +53,7 @@ class ResourceManager(object):
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs self.stop_flags = [True] * max_num_seqs
self.enable_prefix_cache = config.cache_config.enable_prefix_caching self.enable_prefix_cache = config.cache_config.enable_prefix_caching
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
splitwise_role,
local_data_parallel_id)
self.tasks_list = [None] * max_num_seqs self.tasks_list = [None] * max_num_seqs
self.req_dict = dict() self.req_dict = dict()
# current batch status of the engine # current batch status of the engine
@@ -77,8 +77,7 @@ class ResourceManager(object):
Returns: Returns:
int: block number int: block number
""" """
block_num = (input_token_num + self.cfg.block_size - 1 + block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
self.cfg.dec_token_num) // self.cfg.block_size
return block_num return block_num
def get_encoder_block_number(self, input_token_num): def get_encoder_block_number(self, input_token_num):
@@ -91,8 +90,7 @@ class ResourceManager(object):
Returns: Returns:
int: encoder block number int: encoder block number
""" """
enc_block_num = (input_token_num + self.cfg.block_size - enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
1) // self.cfg.block_size
return enc_block_num return enc_block_num
def get_decoder_block_number(self): def get_decoder_block_number(self):
@@ -102,8 +100,7 @@ class ResourceManager(object):
Returns: Returns:
int: decoder block number int: decoder block number
""" """
return (self.cfg.dec_token_num + self.cfg.block_size - return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
1) // self.cfg.block_size
def total_block_number(self): def total_block_number(self):
""" """
@@ -132,13 +129,12 @@ class ResourceManager(object):
elif required_type == "decoder": elif required_type == "decoder":
block_num = self.get_decoder_block_number() block_num = self.get_decoder_block_number()
else: else:
raise ValueError('unknown required type') raise ValueError("unknown required type")
block_list = list() block_list = list()
current_block_num = self.available_block_num() current_block_num = self.available_block_num()
if block_num > current_block_num: if block_num > current_block_num:
llm_logger.error("block_num:{0} > free_list len:{1}".format( llm_logger.error(f"block_num:{block_num} > free_list len:{current_block_num}")
block_num, current_block_num))
return block_list return block_list
block_list = self.cache_manager.allocate_gpu_blocks(block_num) block_list = self.cache_manager.allocate_gpu_blocks(block_num)
llm_logger.debug(f"dispatch {len(block_list)} blocks.") llm_logger.debug(f"dispatch {len(block_list)} blocks.")
@@ -172,10 +168,8 @@ class ResourceManager(object):
ori_number = self.available_block_num() ori_number = self.available_block_num()
self.cache_manager.recycle_gpu_blocks(block_tables) self.cache_manager.recycle_gpu_blocks(block_tables)
cur_number = self.available_block_num() cur_number = self.available_block_num()
main_process_metrics.gpu_cache_usage_perc.set( main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
self.get_gpu_cache_usage_perc()) llm_logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
llm_logger.info(
f"recycle {req_id} {cur_number - ori_number} blocks.")
def available_batch(self): def available_batch(self):
""" """
@@ -238,8 +232,7 @@ class ResourceManager(object):
can_insert = False can_insert = False
while allocated_position + 1 <= self.max_num_seqs: while allocated_position + 1 <= self.max_num_seqs:
if sum(self.stop_flags[allocated_position:allocated_position + if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
1]) == 1:
can_insert = True can_insert = True
break break
allocated_position += 1 allocated_position += 1
@@ -249,72 +242,63 @@ class ResourceManager(object):
task = tasks[processing_task_index] task = tasks[processing_task_index]
if task.get("seed") is None: if task.get("seed") is None:
task.set("seed", task.set("seed", random.randint(0, 9223372036854775807))
random.randint(0, 9223372036854775807))
task.idx = allocated_position task.idx = allocated_position
if self.enable_prefix_cache: if self.enable_prefix_cache:
cache_prepare_time = time.time() cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids( common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
task, self.cfg.block_size, self.cfg.dec_token_num) task,
self.cfg.block_size,
self.cfg.dec_token_num,
)
if unique_block_ids is None: if unique_block_ids is None:
llm_logger.warning( llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
"req_id: {0} not enough blocks available".
format(task["req_id"]))
return return
cached_len = self._record_request_cache_info( cached_len = self._record_request_cache_info(
task, common_block_ids, unique_block_ids, hit_info) task, common_block_ids, unique_block_ids, hit_info
task.cache_prepare_time = time.time( )
) - cache_prepare_time task.cache_prepare_time = time.time() - cache_prepare_time
if task.disaggregate_info is not None: if task.disaggregate_info is not None:
if task.disaggregate_info['role'] == "prefill": if task.disaggregate_info["role"] == "prefill":
self.req_dict[ self.req_dict[task.request_id] = allocated_position
task.request_id] = allocated_position task.disaggregate_info["block_tables"] = task.block_tables
task.disaggregate_info[
'block_tables'] = task.block_tables
self._delete_cached_data(task, cached_len) self._delete_cached_data(task, cached_len)
elif task.disaggregate_info['role'] == "decode": elif task.disaggregate_info["role"] == "decode":
self.req_dict[ self.req_dict[task.request_id] = allocated_position
task.request_id] = allocated_position task.disaggregate_info["block_tables"] = task.need_block_tables
task.disaggregate_info[
'block_tables'] = task.need_block_tables
else: else:
self._delete_cached_data(task, cached_len) self._delete_cached_data(task, cached_len)
else: else:
block_tables = self._get_block_tables( block_tables = self._get_block_tables(task.prompt_token_ids_len)
task.prompt_token_ids_len)
if not block_tables: if not block_tables:
llm_logger.error( llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
"req_id: {0} block_tables is empty".format(
task.request_id))
continue continue
else: else:
task.block_tables = block_tables task.block_tables = block_tables
task.need_block_tables = task.block_tables task.need_block_tables = task.block_tables
if task.disaggregate_info is not None: if task.disaggregate_info is not None:
task.disaggregate_info[ task.disaggregate_info["block_tables"] = block_tables
'block_tables'] = block_tables if task.disaggregate_info["role"] == "prefill":
if task.disaggregate_info['role'] == "prefill": self.req_dict[task.request_id] = allocated_position
self.req_dict[ elif task.disaggregate_info["role"] == "decode":
task.request_id] = allocated_position self.req_dict[task.request_id] = allocated_position
elif task.disaggregate_info['role'] == "decode":
self.req_dict[
task.request_id] = allocated_position
processed_tasks.append(task) processed_tasks.append(task)
self.stop_flags[allocated_position] = False self.stop_flags[allocated_position] = False
task.inference_start_time = time.time() task.inference_start_time = time.time()
task.inference_time_cost = -1.0 task.inference_time_cost = -1.0
task.tokens_all_num = int(0) task.tokens_all_num = 0
self.tasks_list[allocated_position] = task self.tasks_list[allocated_position] = task
llm_logger.info( llm_logger.info(
f"Allocate request: {task.request_id}, " f"Allocate request: {task.request_id}, "
f"allocated_position:{allocated_position}, " f"allocated_position:{allocated_position}, "
f"length of prompt token: {task.prompt_token_ids_len}") f"length of prompt token: {task.prompt_token_ids_len}"
)
allocated_position += 1 allocated_position += 1
processing_task_index += 1 processing_task_index += 1
@@ -325,11 +309,10 @@ class ResourceManager(object):
break break
llm_logger.info( llm_logger.info(
f"Number of allocated requests: {len(tasks)}, number of " f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
f"running requests in worker: {self.real_bsz}") )
llm_logger.info(f"{self.info()}") llm_logger.info(f"{self.info()}")
main_process_metrics.gpu_cache_usage_perc.set( main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
self.get_gpu_cache_usage_perc())
return processed_tasks return processed_tasks
@@ -345,19 +328,15 @@ class ResourceManager(object):
task.seq_lens_decoder = cached_len task.seq_lens_decoder = cached_len
task.prompt_token_ids_len = len(task.prompt_token_ids) task.prompt_token_ids_len = len(task.prompt_token_ids)
def _record_request_cache_info(self, task, common_block_ids, def _record_request_cache_info(self, task, common_block_ids, unique_block_ids, hit_info):
unique_block_ids, hit_info):
""" """
Record the cache information for a given task and its corresponding block IDs. Record the cache information for a given task and its corresponding block IDs.
""" """
cache_block_num = len(common_block_ids) cache_block_num = len(common_block_ids)
no_cache_block_num = math.ceil(len(task.prompt_token_ids) / self.cfg.block_size \ no_cache_block_num = math.ceil(len(task.prompt_token_ids) / self.cfg.block_size - cache_block_num)
- cache_block_num)
task.num_cached_tokens = cache_block_num * self.cfg.block_size task.num_cached_tokens = cache_block_num * self.cfg.block_size
task.gpu_cache_token_num = hit_info[ task.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.cfg.block_size
"gpu_cache_blocks"] * self.cfg.block_size task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
task.cpu_cache_token_num = hit_info[
"cpu_cache_blocks"] * self.cfg.block_size
task.cache_info = (cache_block_num, no_cache_block_num) task.cache_info = (cache_block_num, no_cache_block_num)
cached_len = len(common_block_ids) * self.cfg.block_size cached_len = len(common_block_ids) * self.cfg.block_size
@@ -374,9 +353,11 @@ class ResourceManager(object):
Returns: Returns:
str: resource manager info str: resource manager info
""" """
info = f"ResourceManager info, " \ info = (
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \ f"ResourceManager info, "
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, "
f"available_block_num: {self.available_block_num()}, available_batch: {self.available_batch()}" f"available_block_num: {self.available_block_num()}, available_batch: {self.available_batch()}"
)
return info return info
def get_gpu_cache_usage_perc(self): def get_gpu_cache_usage_perc(self):
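Note: the collapsed block-count expressions in ResourceManager are plain ceiling divisions over the block size. A small worked example (block_size=64 and dec_token_num=0 are illustrative values, not the engine defaults):

    def required_block_number(input_token_num: int, block_size: int, dec_token_num: int) -> int:
        # ceil((input_token_num + dec_token_num) / block_size), written with integer arithmetic
        return (input_token_num + block_size - 1 + dec_token_num) // block_size

    assert required_block_number(128, 64, 0) == 2
    assert required_block_number(130, 64, 0) == 3  # two extra tokens spill into a third block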

View File

@@ -94,18 +94,18 @@ class SamplingParams:
bad_words: Optional[List[str]] = None bad_words: Optional[List[str]] = None
@classmethod @classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams": def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
"""Create instance from command line arguments""" """Create instance from command line arguments"""
return cls( return cls(
**{ **{
field.name: field.name: (req_dict[field.name] if field.name in req_dict else field.default)
req_dict[field.name] if field.name in
req_dict else field.default
for field in fields(cls) for field in fields(cls)
}) }
)
@classmethod @classmethod
def from_optional(cls, def from_optional(
cls,
n, n,
best_of, best_of,
presence_penalty, presence_penalty,
@@ -121,16 +121,15 @@ class SamplingParams:
reasoning_max_tokens=None, reasoning_max_tokens=None,
min_tokens=1, min_tokens=1,
logprobs=None, logprobs=None,
bad_words=None) -> "SamplingParams": bad_words=None,
) -> SamplingParams:
"""Create instance from command line arguments""" """Create instance from command line arguments"""
return cls(n=1 if n is None else n, return cls(
n=1 if n is None else n,
best_of=best_of, best_of=best_of,
presence_penalty=presence_penalty presence_penalty=(presence_penalty if presence_penalty is not None else 0.0),
if presence_penalty is not None else 0.0, frequency_penalty=(frequency_penalty if frequency_penalty is not None else 0.0),
frequency_penalty=frequency_penalty repetition_penalty=(repetition_penalty if repetition_penalty is not None else 1.0),
if frequency_penalty is not None else 0.0,
repetition_penalty=repetition_penalty
if repetition_penalty is not None else 1.0,
temperature=temperature if temperature is not None else 1.0, temperature=temperature if temperature is not None else 1.0,
top_p=top_p, top_p=top_p,
top_k=top_k if top_k is not None else 0, top_k=top_k if top_k is not None else 0,
@@ -141,7 +140,8 @@ class SamplingParams:
reasoning_max_tokens=reasoning_max_tokens, reasoning_max_tokens=reasoning_max_tokens,
min_tokens=min_tokens, min_tokens=min_tokens,
logprobs=logprobs, logprobs=logprobs,
bad_words=bad_words) bad_words=bad_words,
)
def __post_init__(self): def __post_init__(self):
if self.seed is None: if self.seed is None:
@@ -152,60 +152,44 @@ class SamplingParams:
def _verify_args(self) -> None: def _verify_args(self) -> None:
if not isinstance(self.n, int): if not isinstance(self.n, int):
raise ValueError( raise ValueError(f"n must be an int, but is of type {type(self.n)}")
f"n must be an int, but is of type {type(self.n)}")
if self.n < 1: if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.") raise ValueError(f"n must be at least 1, got {self.n}.")
if self.presence_penalty is not None and ( if self.presence_penalty is not None and (not -2.0 <= self.presence_penalty <= 2.0):
not -2.0 <= self.presence_penalty <= 2.0): raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.")
raise ValueError("presence_penalty must be in [-2, 2], got " if self.frequency_penalty is not None and (not -2.0 <= self.frequency_penalty <= 2.0):
f"{self.presence_penalty}.") raise ValueError("frequency_penalty must be in [-2, 2], got " f"{self.frequency_penalty}.")
if self.frequency_penalty is not None and (
not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if self.repetition_penalty is not None and self.repetition_penalty <= 0.0: if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
raise ValueError( raise ValueError("repetition_penalty must be greater than zero, got " f"{self.repetition_penalty}.")
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
if self.temperature is not None and self.temperature < 0.0: if self.temperature is not None and self.temperature < 0.0:
raise ValueError( raise ValueError(f"temperature must be non-negative, got {self.temperature}.")
f"temperature must be non-negative, got {self.temperature}.")
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0: if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.") raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
# quietly accept -1 as disabled, but prefer 0 # quietly accept -1 as disabled, but prefer 0
if self.top_k < -1: if self.top_k < -1:
raise ValueError(f"top_k must be 0 (disable), or at least 1, " raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
f"got {self.top_k}.")
if not isinstance(self.top_k, int): if not isinstance(self.top_k, int):
raise TypeError( raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
f"top_k must be an integer, got {type(self.top_k).__name__}")
if self.max_tokens is not None and self.max_tokens < 1: if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError( raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens: if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
raise ValueError( raise ValueError(f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
if self.min_tokens < 0: if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, " raise ValueError(f"min_tokens must be greater than or equal to 0, " f"got {self.min_tokens}.")
f"got {self.min_tokens}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens: if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError( raise ValueError(
f"min_tokens must be less than or equal to " f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
f"max_tokens={self.max_tokens}, got {self.min_tokens}.") )
if self.logprobs is not None and self.logprobs < 0: if self.logprobs is not None and self.logprobs < 0:
raise ValueError( raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
f"logprobs must be non-negative, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20: if self.logprobs is not None and self.logprobs > 20:
raise ValueError( raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
"Invalid value for 'top_logprobs': must be less than or equal to 20.")
if not 0 <= self.seed <= 922337203685477580: if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
f"{self.seed}.")
def update_from_tokenizer(self, tokenizer): def update_from_tokenizer(self, tokenizer):
""" """
@@ -218,6 +202,7 @@ class SamplingParams:
@dataclass @dataclass
class BeamSearchParams: class BeamSearchParams:
"""Beam search parameters for text generation.""" """Beam search parameters for text generation."""
beam_width: int beam_width: int
max_tokens: int max_tokens: int
ignore_eos: bool = False ignore_eos: bool = False
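Note: several of the one-line ValueError messages above rely on implicit concatenation of an adjacent plain literal and an f-string, which the formatter leaves in place. A tiny illustration of that behaviour (the value is a placeholder):

    value = 3.5
    msg = "presence_penalty must be in [-2, 2], got " f"{value}."
    assert msg == "presence_penalty must be in [-2, 2], got 3.5."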

View File

@@ -14,19 +14,25 @@
# limitations under the License. # limitations under the License.
""" """
import uvicorn
import json import json
import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from fastdeploy.utils import FlexibleArgumentParser, api_server_logger, is_port_available
from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.engine import LLMEngine
from fastdeploy.utils import (
FlexibleArgumentParser,
api_server_logger,
is_port_available,
)
app = FastAPI() app = FastAPI()
llm_engine = None llm_engine = None
def init_app(args): def init_app(args):
""" """
init LLMEngine init LLMEngine
@@ -39,7 +45,7 @@ def init_app(args):
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!") api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
return False return False
api_server_logger.info(f"FastDeploy LLM engine initialized!") api_server_logger.info("FastDeploy LLM engine initialized!")
return True return True
@@ -48,6 +54,7 @@ async def health() -> Response:
"""Health check.""" """Health check."""
return Response(status_code=200) return Response(status_code=200)
@app.post("/generate") @app.post("/generate")
async def generate(request: dict): async def generate(request: dict):
""" """
@@ -64,7 +71,7 @@ async def generate(request: dict):
output = result output = result
except Exception as e: except Exception as e:
# Log the full exception stack trace # Log the full exception stack trace
api_server_logger.error(f"Error during generation: {str(e)}", exc_info=True) api_server_logger.error(f"Error during generation: {e!s}", exc_info=True)
# Return a structured error message and end the stream # Return a structured error message and end the stream
output = {"error": str(e), "error_type": e.__class__.__name__} output = {"error": str(e), "error_type": e.__class__.__name__}
return output return output
@@ -76,12 +83,14 @@ async def generate(request: dict):
yield f"data: {json.dumps(result)}\n\n" yield f"data: {json.dumps(result)}\n\n"
except Exception as e: except Exception as e:
# Log the full exception stack trace # Log the full exception stack trace
api_server_logger.error(f"Error during generation: {str(e)}", exc_info=True) api_server_logger.error(f"Error during generation: {e!s}", exc_info=True)
# Return a structured error message and end the stream # Return a structured error message and end the stream
error_msg = {"error": str(e), "error_type": e.__class__.__name__} error_msg = {"error": str(e), "error_type": e.__class__.__name__}
yield f"data: {json.dumps(error_msg)}\n\n" yield f"data: {json.dumps(error_msg)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream") return StreamingResponse(event_generator(), media_type="text/event-stream")
def launch_api_server(args) -> None: def launch_api_server(args) -> None:
""" """
启动http服务 启动http服务
@@ -97,11 +106,13 @@ def launch_api_server(args) -> None:
return return
try: try:
uvicorn.run(app=app, uvicorn.run(
app=app,
host=args.host, host=args.host,
port=args.port, port=args.port,
workers=args.workers, workers=args.workers,
log_level="info") # set log level to error to avoid log log_level="info",
) # set log level to error to avoid log
except Exception as e: except Exception as e:
api_server_logger.error(f"launch sync http server error, {e}") api_server_logger.error(f"launch sync http server error, {e}")

View File

@@ -14,35 +14,45 @@
# limitations under the License. # limitations under the License.
""" """
from typing import Literal, Union, List
from typing_extensions import Required, TypedDict, TypeAlias
from openai.types.chat import ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
from urllib.parse import urlparse
import requests
from copy import deepcopy from copy import deepcopy
from typing import List, Literal, Union
from urllib.parse import urlparse
import requests
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from typing_extensions import Required, TypeAlias, TypedDict
from fastdeploy.input.multimodal.video import VideoMediaIO
from fastdeploy.input.multimodal.image import ImageMediaIO from fastdeploy.input.multimodal.image import ImageMediaIO
from fastdeploy.input.multimodal.video import VideoMediaIO
class VideoURL(TypedDict, total=False): class VideoURL(TypedDict, total=False):
"""Video URL object""" """Video URL object"""
url: Required[str] url: Required[str]
"""Either a URL of the video or the base64 encoded video data""" """Either a URL of the video or the base64 encoded video data"""
class CustomChatCompletionContentPartVideoParam(TypedDict, total=False): class CustomChatCompletionContentPartVideoParam(TypedDict, total=False):
"""Custom Video URL object""" """Custom Video URL object"""
video_url: Required[VideoURL] video_url: Required[VideoURL]
type: Required[Literal["video_url"]] type: Required[Literal["video_url"]]
"""The type of the content type.""" """The type of the content type."""
CustomChatCompletionContentPartParam: TypeAlias = Union[ CustomChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, CustomChatCompletionContentPartVideoParam OpenAIChatCompletionContentPartParam,
CustomChatCompletionContentPartVideoParam,
] ]
class CustomChatCompletionMessageParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Custom User chat message parameter.""" """Custom User chat message parameter."""
@@ -58,11 +68,13 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
Provides the model information to differentiate between participants of the same role. Provides the model information to differentiate between participants of the same role.
""" """
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam]
class MultiModalPartParser(object): class MultiModalPartParser:
"""Multi Modal Part parser""" """Multi Modal Part parser"""
def __init__(self): def __init__(self):
self.image_io = ImageMediaIO() self.image_io = ImageMediaIO()
self.video_io = VideoMediaIO() self.video_io = VideoMediaIO()
@@ -92,6 +104,7 @@ class MultiModalPartParser(object):
localpath = parsed.path localpath = parsed.path
return media_io.load_file(localpath) return media_io.load_file(localpath)
def parse_content_part(mm_parser, part): def parse_content_part(mm_parser, part):
"""only support openai compatible format for now""" """only support openai compatible format for now"""
@@ -120,6 +133,7 @@ def parse_content_part(mm_parser, part):
raise ValueError(f"Unknown content part type: {part_type}") raise ValueError(f"Unknown content part type: {part_type}")
# TODO async # TODO async
# def parse_chat_messages(messages: List[ChatCompletionMessageParam]): # def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
def parse_chat_messages(messages): def parse_chat_messages(messages):
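Note: the import reshuffle above does not change the message schema. A short sketch of building the custom video content part (the URL is a placeholder; VideoURL/VideoPart mirror the TypedDicts defined in this file):

    from typing import Literal
    from typing_extensions import Required, TypedDict

    class VideoURL(TypedDict, total=False):
        url: Required[str]

    class VideoPart(TypedDict, total=False):  # mirrors CustomChatCompletionContentPartVideoParam
        video_url: Required[VideoURL]
        type: Required[Literal["video_url"]]

    part: VideoPart = {
        "type": "video_url",
        "video_url": {"url": "https://example.com/demo.mp4"},
    }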

View File

@@ -14,17 +14,15 @@
# limitations under the License. # limitations under the License.
""" """
import zmq
import time import time
from random import randint
import uuid import uuid
import numpy as np import numpy as np
from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.inter_communicator import ZmqClient, IPCSignal
from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger, EngineError from fastdeploy.utils import EngineError, api_server_logger
class EngineClient: class EngineClient:
@@ -32,23 +30,36 @@ class EngineClient:
EngineClient is a class that handles the communication between the client and the server. EngineClient is a class that handles the communication between the client and the server.
""" """
def __init__(self, tokenizer, max_model_len, tensor_parallel_size, pid, limit_mm_per_prompt, mm_processor_kwargs, def __init__(
enable_mm=False, reasoning_parser=None): self,
input_processor = InputPreprocessor(tokenizer, tokenizer,
max_model_len,
tensor_parallel_size,
pid,
limit_mm_per_prompt,
mm_processor_kwargs,
enable_mm=False,
reasoning_parser=None,
):
input_processor = InputPreprocessor(
tokenizer,
reasoning_parser, reasoning_parser,
limit_mm_per_prompt, limit_mm_per_prompt,
mm_processor_kwargs, mm_processor_kwargs,
enable_mm) enable_mm,
)
self.enable_mm = enable_mm self.enable_mm = enable_mm
self.reasoning_parser = reasoning_parser self.reasoning_parser = reasoning_parser
self.data_processor = input_processor.create_processor() self.data_processor = input_processor.create_processor()
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size], dtype=np.int32) self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(name="worker_healthy_live_signal", self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=self.worker_healthy_live_recorded_time_array, array=self.worker_healthy_live_recorded_time_array,
dtype=np.int32, dtype=np.int32,
suffix=pid, suffix=pid,
create=False) create=False,
)
model_weights_status = np.zeros([1], dtype=np.int32) model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal( self.model_weights_status_signal = IPCSignal(
@@ -56,7 +67,8 @@ class EngineClient:
array=model_weights_status, array=model_weights_status,
dtype=np.int32, dtype=np.int32,
suffix=pid, suffix=pid,
create=False) create=False,
)
def create_zmq_client(self, model, mode): def create_zmq_client(self, model, mode):
""" """
@@ -75,7 +87,6 @@ class EngineClient:
if "request_id" not in prompts: if "request_id" not in prompts:
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
prompts["request_id"] = request_id prompts["request_id"] = request_id
query_list = []
if "max_tokens" not in prompts: if "max_tokens" not in prompts:
prompts["max_tokens"] = self.max_model_len - 1 prompts["max_tokens"] = self.max_model_len - 1
@@ -105,8 +116,8 @@ class EngineClient:
if task.get("reasoning_max_tokens", None) is None: if task.get("reasoning_max_tokens", None) is None:
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1) task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
min_tokens = task.get("min_tokens", 1) min_tokens = task.get("min_tokens", 1)
if 'messages' in task: if "messages" in task:
del task['messages'] del task["messages"]
api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}") api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}")
work_process_metrics.request_params_max_tokens.observe(task["max_tokens"]) work_process_metrics.request_params_max_tokens.observe(task["max_tokens"])
work_process_metrics.prompt_tokens_total.inc(input_ids_len) work_process_metrics.prompt_tokens_total.inc(input_ids_len)
@@ -133,8 +144,7 @@ class EngineClient:
task["preprocess_end_time"] = time.time() task["preprocess_end_time"] = time.time()
preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"] preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
api_server_logger.info( api_server_logger.info(
f"Cache request with request_id ({task.get('request_id')}), " f"Cache request with request_id ({task.get('request_id')}), " f"cost {time.time() - preprocess_cost_time}"
f"cost {time.time() - preprocess_cost_time}"
) )
self.vaild_parameters(task) self.vaild_parameters(task)
@@ -153,7 +163,6 @@ class EngineClient:
Validate stream options Validate stream options
""" """
if data.get("n"): if data.get("n"):
if data["n"] != 1: if data["n"] != 1:
raise ValueError("n only support 1.") raise ValueError("n only support 1.")
@@ -168,9 +177,7 @@ class EngineClient:
if data.get("top_p"): if data.get("top_p"):
if data["top_p"] > 1 or data["top_p"] < 0: if data["top_p"] > 1 or data["top_p"] < 0:
raise ValueError( raise ValueError("top_p value can only be defined [0, 1].")
"top_p value can only be defined [0, 1].")
if data.get("frequency_penalty"): if data.get("frequency_penalty"):
if not -2.0 <= data["frequency_penalty"] <= 2.0: if not -2.0 <= data["frequency_penalty"] <= 2.0:
@@ -178,24 +185,18 @@ class EngineClient:
if data.get("temperature"): if data.get("temperature"):
if data["temperature"] < 0: if data["temperature"] < 0:
raise ValueError(f"temperature must be non-negative") raise ValueError("temperature must be non-negative")
if data.get("presence_penalty"): if data.get("presence_penalty"):
if not -2.0 <= data["presence_penalty"] <= 2.0: if not -2.0 <= data["presence_penalty"] <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2]") raise ValueError("presence_penalty must be in [-2, 2]")
if data.get("seed"): if data.get("seed"):
if not 0 <= data["seed"] <= 922337203685477580: if not 0 <= data["seed"] <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580]") raise ValueError("seed must be in [0, 922337203685477580]")
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError( raise ValueError("Stream options can only be defined when `stream=True`.")
"Stream options can only be defined when `stream=True`.")
def check_health(self, time_interval_threashold=30): def check_health(self, time_interval_threashold=30):
""" """
@@ -209,7 +210,6 @@ class EngineClient:
return True, "" return True, ""
def is_workers_alive(self): def is_workers_alive(self):
""" """
Check the health of the model server by checking whether all workers are alive. Check the health of the model server by checking whether all workers are alive.
@@ -220,8 +220,6 @@ class EngineClient:
else: else:
return False, "No model weight enabled" return False, "No model weight enabled"
def update_model_weight(self, timeout=300): def update_model_weight(self, timeout=300):
""" """
Update the model weight by sending a signal to the server. Update the model weight by sending a signal to the server.
@@ -244,8 +242,6 @@ class EngineClient:
time.sleep(1) time.sleep(1)
return True, "" return True, ""
def clear_load_weight(self, timeout=300): def clear_load_weight(self, timeout=300):
""" """
Clear the load weight status. Clear the load weight status.

View File

@@ -28,6 +28,7 @@ from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.engine.sampling_params import SamplingParams
# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger, retrive_model_from_server from fastdeploy.utils import llm_logger, retrive_model_from_server
@@ -78,16 +79,14 @@ class LLM:
# Create the Engine # Create the Engine
self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args) self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)
self.default_sampling_params = SamplingParams( self.default_sampling_params = SamplingParams(max_tokens=self.llm_engine.cfg.max_model_len)
max_tokens=self.llm_engine.cfg.max_model_len)
self.llm_engine.start() self.llm_engine.start()
self.mutex = threading.Lock() self.mutex = threading.Lock()
self.req_output = dict() self.req_output = dict()
self.master_node_ip = self.llm_engine.cfg.master_ip self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread( self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
target=self._receive_output, daemon=True)
self._receive_output_thread.start() self._receive_output_thread.start()
def _check_master(self): def _check_master(self):
@@ -111,15 +110,19 @@ class LLM:
continue continue
self.req_output[request_id].add(result) self.req_output[request_id].add(result)
except Exception as e: except Exception as e:
llm_logger.error("Unexcepted error happend: {}, {}".format( llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
e, str(traceback.format_exc())))
def generate( def generate(
self, self,
prompts: Union[str, list[str], list[int], list[list[int]], prompts: Union[
dict[str, Any], list[dict[str, Any]]], str,
sampling_params: Optional[Union[SamplingParams, list[str],
list[SamplingParams]]] = None, list[int],
list[list[int]],
dict[str, Any],
list[dict[str, Any]],
],
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
): ):
""" """
@@ -161,11 +164,9 @@ class LLM:
# sampling_params = None # sampling_params = None
if sampling_params_len != 1 and len(prompts) != sampling_params_len: if sampling_params_len != 1 and len(prompts) != sampling_params_len:
raise ValueError( raise ValueError("prompts and sampling_params must be the same length.")
"prompts and sampling_params must be the same length.")
req_ids = self._add_request(prompts=prompts, req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
sampling_params=sampling_params)
# get output # get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm) outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -176,8 +177,7 @@ class LLM:
def chat( def chat(
self, self,
messages: Union[list[Any], list[list[Any]]], messages: Union[list[Any], list[list[Any]]],
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
list[SamplingParams]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
chat_template_kwargs: Optional[dict[str, Any]] = None, chat_template_kwargs: Optional[dict[str, Any]] = None,
): ):
@@ -211,15 +211,16 @@ class LLM:
messages = [messages] messages = [messages]
if sampling_params_len != 1 and len(messages) != sampling_params_len: if sampling_params_len != 1 and len(messages) != sampling_params_len:
raise ValueError( raise ValueError("messages and sampling_params must be the same length.")
"messages and sampling_params must be the same length.")
messages_len = len(messages) messages_len = len(messages)
for i in range(messages_len): for i in range(messages_len):
messages[i] = {"messages": messages[i]} messages[i] = {"messages": messages[i]}
req_ids = self._add_request(prompts=messages, req_ids = self._add_request(
prompts=messages,
sampling_params=sampling_params, sampling_params=sampling_params,
chat_template_kwargs=chat_template_kwargs) chat_template_kwargs=chat_template_kwargs,
)
# get output # get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm) outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -253,8 +254,7 @@ class LLM:
"prompt": prompts[i], "prompt": prompts[i],
"request_id": request_id, "request_id": request_id,
} }
elif isinstance(prompts[i], list) and isinstance( elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
prompts[i][0], int):
tasks = { tasks = {
"prompt_token_ids": prompts[i], "prompt_token_ids": prompts[i],
"request_id": request_id, "request_id": request_id,
@@ -273,11 +273,8 @@ class LLM:
current_sampling_params = sampling_params current_sampling_params = sampling_params
enable_thinking = None enable_thinking = None
if chat_template_kwargs is not None: if chat_template_kwargs is not None:
enable_thinking = chat_template_kwargs.get( enable_thinking = chat_template_kwargs.get("enable_thinking", None)
"enable_thinking", None) self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
self.llm_engine.add_requests(tasks,
current_sampling_params,
enable_thinking=enable_thinking)
return req_ids return req_ids
def _run_engine(self, req_ids: list[str], use_tqdm: bool): def _run_engine(self, req_ids: list[str], use_tqdm: bool):
@@ -303,8 +300,7 @@ class LLM:
total=num_requests, total=num_requests,
desc="Processed prompts", desc="Processed prompts",
dynamic_ncols=True, dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, " postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
f"output: {0:.2f} toks/s"),
) )
output = [None] * num_requests output = [None] * num_requests
@@ -322,13 +318,11 @@ class LLM:
continue continue
result = self.req_output.pop(req_id) result = self.req_output.pop(req_id)
result = self.llm_engine.data_processor.process_response( result = self.llm_engine.data_processor.process_response(result)
result)
output[pos] = result output[pos] = result
finished.append(i) finished.append(i)
llm_logger.debug( llm_logger.debug(f"Request id: {req_id} has been completed.")
"Request id: {} has been completed.".format(req_id))
if use_tqdm: if use_tqdm:
pbar.update(1) pbar.update(1)
@@ -346,24 +340,27 @@ if __name__ == "__main__":
# llm = LLM(model="llama_model") # llm = LLM(model="llama_model")
# output = llm.generate(prompts="who are you", use_tqdm=True) # output = llm.generate(prompts="who are you", use_tqdm=True)
# print(output) # print(output)
llm = LLM(model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B", llm = LLM(
tensor_parallel_size=2) model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
tensor_parallel_size=2,
)
sampling_params = SamplingParams(temperature=0.1, max_tokens=30) sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
output = llm.generate(prompts="who are you", output = llm.generate(prompts="who are you", use_tqdm=True, sampling_params=sampling_params)
print(output)
output = llm.generate(
prompts=["who are you", "what can you do"],
sampling_params=SamplingParams(temperature=1, max_tokens=50),
use_tqdm=True, use_tqdm=True,
sampling_params=sampling_params) )
print(output) print(output)
output = llm.generate(prompts=["who are you", "what can you do"], output = llm.generate(
sampling_params=SamplingParams(temperature=1, prompts=["who are you", "I miss you"],
max_tokens=50),
use_tqdm=True)
print(output)
output = llm.generate(prompts=["who are you", "I miss you"],
sampling_params=[ sampling_params=[
SamplingParams(temperature=1, max_tokens=50), SamplingParams(temperature=1, max_tokens=50),
SamplingParams(temperature=1, max_tokens=20) SamplingParams(temperature=1, max_tokens=20),
], ],
use_tqdm=True) use_tqdm=True,
)
print(output) print(output)
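Note: for the reformatted chat() path, usage follows the same shape as the generate() calls in this __main__ block; a hedged sketch that reuses the LLM and SamplingParams names already imported above (model path and message content are placeholders):

    llm = LLM(model="/path/to/model", tensor_parallel_size=1)
    outputs = llm.chat(
        messages=[[{"role": "user", "content": "who are you"}]],
        sampling_params=SamplingParams(temperature=0.1, max_tokens=30),
        use_tqdm=True,
    )
    print(outputs)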

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
import os import os
import threading import threading
import time import time
@@ -24,46 +25,41 @@ import zmq
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST from prometheus_client import CONTENT_TYPE_LATEST
from fastdeploy.metrics.trace_util import inject_to_metadata,instrument
from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.engine import LLMEngine
from fastdeploy.entrypoints.engine_client import EngineClient from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest, from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
ControlSchedulerRequest,
ErrorResponse, ErrorResponse,
ControlSchedulerRequest) )
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import \ from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
OpenAIServingCompletion from fastdeploy.metrics.metrics import (
from fastdeploy.metrics.metrics import (EXCLUDE_LABELS, EXCLUDE_LABELS,
cleanup_prometheus_files, cleanup_prometheus_files,
get_filtered_metrics, get_filtered_metrics,
main_process_metrics) main_process_metrics,
from fastdeploy.utils import (FlexibleArgumentParser, api_server_logger, )
console_logger, is_port_available, from fastdeploy.metrics.trace_util import inject_to_metadata, instrument
retrive_model_from_server) from fastdeploy.utils import (
FlexibleArgumentParser,
api_server_logger,
console_logger,
is_port_available,
retrive_model_from_server,
)
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--port", parser.add_argument("--port", default=8000, type=int, help="port to the http server")
default=8000, parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
type=int,
help="port to the http server")
parser.add_argument("--host",
default="0.0.0.0",
type=str,
help="host to the http server")
parser.add_argument("--workers", default=1, type=int, help="number of workers") parser.add_argument("--workers", default=1, type=int, help="number of workers")
parser.add_argument("--metrics-port", parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
default=8001, parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
type=int,
help="port for metrics server")
parser.add_argument("--controller-port",
default=-1,
type=int,
help="port for controller server")
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
args.model = retrive_model_from_server(args.model) args.model = retrive_model_from_server(args.model)
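As a quick sanity check of the flattened argparse block above, the same flags can be driven programmatically. The sketch below uses the standard argparse module as a stand-in for FlexibleArgumentParser (assumed here to be argparse-compatible) and omits the engine options that EngineArgs.add_cli_args would add; the port values are made up.

import argparse

# Stand-in for FlexibleArgumentParser; the real parser is further extended
# by EngineArgs.add_cli_args(parser) with engine-level options.
parser = argparse.ArgumentParser()
parser.add_argument("--port", default=8000, type=int, help="port to the http server")
parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
parser.add_argument("--workers", default=1, type=int, help="number of workers")
parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")

args = parser.parse_args(["--port", "8188", "--metrics-port", "8189"])
print(args.port, args.metrics_port)  # 8188 8189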
@@ -79,26 +75,18 @@ def load_engine():
if llm_engine is not None: if llm_engine is not None:
return llm_engine return llm_engine
api_server_logger.info( api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args) engine = LLMEngine.from_engine_args(engine_args)
if not engine.start(api_server_pid=os.getpid()): if not engine.start(api_server_pid=os.getpid()):
api_server_logger.error( api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
"Failed to initialize FastDeploy LLM engine, service exit now!")
return None return None
api_server_logger.info("FastDeploy LLM engine initialized!\n") api_server_logger.info("FastDeploy LLM engine initialized!\n")
console_logger.info( console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics")
f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics" console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions")
) console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions")
console_logger.info(
f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions"
)
console_logger.info(
f"Launching completion service at http://{args.host}:{args.port}/v1/completions"
)
llm_engine = engine llm_engine = engine
return engine return engine
@@ -111,16 +99,21 @@ async def lifespan(app: FastAPI):
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model
if current_process().name != 'MainProcess': if current_process().name != "MainProcess":
pid = os.getppid() pid = os.getppid()
else: else:
pid = os.getpid() pid = os.getpid()
api_server_logger.info(f"{pid}") api_server_logger.info(f"{pid}")
engine_client = EngineClient(args.tokenizer, args.max_model_len, engine_client = EngineClient(
args.tensor_parallel_size, pid, args.tokenizer,
args.max_model_len,
args.tensor_parallel_size,
pid,
args.limit_mm_per_prompt, args.limit_mm_per_prompt,
args.mm_processor_kwargs, args.enable_mm, args.mm_processor_kwargs,
args.reasoning_parser) args.enable_mm,
args.reasoning_parser,
)
app.state.dynamic_load_weight = args.dynamic_load_weight app.state.dynamic_load_weight = args.dynamic_load_weight
chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip) chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip)
completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip) completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip)
@@ -134,6 +127,7 @@ async def lifespan(app: FastAPI):
try: try:
engine_client.zmq_client.close() engine_client.zmq_client.close()
from prometheus_client import multiprocess from prometheus_client import multiprocess
multiprocess.mark_process_dead(os.getpid()) multiprocess.mark_process_dead(os.getpid())
api_server_logger.info(f"Closing metrics client pid: {pid}") api_server_logger.info(f"Closing metrics client pid: {pid}")
except Exception as e: except Exception as e:
@@ -187,11 +181,7 @@ async def list_all_routes():
if route.path.startswith("/v1"): if route.path.startswith("/v1"):
methods = sorted(route.methods) methods = sorted(route.methods)
tags = getattr(route, "tags", []) or [] tags = getattr(route, "tags", []) or []
routes_info.append({ routes_info.append({"path": route.path, "methods": methods, "tags": tags})
"path": route.path,
"methods": methods,
"tags": tags
})
return {"routes": routes_info} return {"routes": routes_info}
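For reference, the payload returned by this route-listing endpoint has the flattened shape built above; the concrete paths and methods below are illustrative examples rather than captured output.

example_routes = {
    "routes": [
        {"path": "/v1/chat/completions", "methods": ["POST"], "tags": []},
        {"path": "/v1/completions", "methods": ["POST"], "tags": []},
    ]
}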
@@ -209,15 +199,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
if app.state.dynamic_load_weight: if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive() status, msg = app.state.engine_client.is_workers_alive()
if not status: if not status:
return JSONResponse( return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
content={"error": "Worker Service Not Healthy"},
status_code=304)
inject_to_metadata(request) inject_to_metadata(request)
generator = await app.state.chat_handler.create_chat_completion(request) generator = await app.state.chat_handler.create_chat_completion(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(), status_code=generator.code)
status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse): elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
@@ -233,14 +220,11 @@ async def create_completion(request: CompletionRequest):
if app.state.dynamic_load_weight: if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive() status, msg = app.state.engine_client.is_workers_alive()
if not status: if not status:
return JSONResponse( return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
content={"error": "Worker Service Not Healthy"},
status_code=304)
generator = await app.state.completion_handler.create_completion(request) generator = await app.state.completion_handler.create_completion(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(), status_code=generator.code)
status_code=generator.code)
elif isinstance(generator, CompletionResponse): elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
@@ -258,8 +242,7 @@ def update_model_weight(request: Request) -> Response:
return Response(content=msg, status_code=404) return Response(content=msg, status_code=404)
return Response(status_code=200) return Response(status_code=200)
else: else:
return Response(content="Dynamic Load Weight Disabled.", return Response(content="Dynamic Load Weight Disabled.", status_code=404)
status_code=404)
@app.get("/clear_load_weight") @app.get("/clear_load_weight")
@@ -273,8 +256,7 @@ def clear_load_weight(request: Request) -> Response:
return Response(content=msg, status_code=404) return Response(content=msg, status_code=404)
return Response(status_code=200) return Response(status_code=200)
else: else:
return Response(content="Dynamic Load Weight Disabled.", return Response(content="Dynamic Load Weight Disabled.", status_code=404)
status_code=404)
def launch_api_server() -> None: def launch_api_server() -> None:
@@ -284,16 +266,17 @@ def launch_api_server() -> None:
if not is_port_available(args.host, args.port): if not is_port_available(args.host, args.port):
raise Exception(f"The parameter `port`:{args.port} is already in use.") raise Exception(f"The parameter `port`:{args.port} is already in use.")
api_server_logger.info( api_server_logger.info(f"launch Fastdeploy api server... port: {args.port}")
f"launch Fastdeploy api server... port: {args.port}")
api_server_logger.info(f"args: {args.__dict__}") api_server_logger.info(f"args: {args.__dict__}")
try: try:
uvicorn.run(app="fastdeploy.entrypoints.openai.api_server:app", uvicorn.run(
app="fastdeploy.entrypoints.openai.api_server:app",
host=args.host, host=args.host,
port=args.port, port=args.port,
workers=args.workers, workers=args.workers,
log_level="info") # set log level to error to avoid log log_level="info",
) # set log level to error to avoid log
except Exception as e: except Exception as e:
api_server_logger.error(f"launch sync http server error, {e}") api_server_logger.error(f"launch sync http server error, {e}")
@@ -308,8 +291,8 @@ async def metrics():
""" """
metrics_text = get_filtered_metrics( metrics_text = get_filtered_metrics(
EXCLUDE_LABELS, EXCLUDE_LABELS,
extra_register_func=lambda reg: main_process_metrics.register_all( extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=args.workers),
reg, workers=args.workers)) )
return Response(metrics_text, media_type=CONTENT_TYPE_LATEST) return Response(metrics_text, media_type=CONTENT_TYPE_LATEST)
@@ -318,23 +301,17 @@ def run_metrics_server():
run metrics server run metrics server
""" """
uvicorn.run(metrics_app, uvicorn.run(metrics_app, host="0.0.0.0", port=args.metrics_port, log_level="error")
host="0.0.0.0",
port=args.metrics_port,
log_level="error")
def launch_metrics_server(): def launch_metrics_server():
"""Metrics server running the sub thread""" """Metrics server running the sub thread"""
if not is_port_available(args.host, args.metrics_port): if not is_port_available(args.host, args.metrics_port):
raise Exception( raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.")
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
)
prom_dir = cleanup_prometheus_files(True) prom_dir = cleanup_prometheus_files(True)
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
metrics_server_thread = threading.Thread(target=run_metrics_server, metrics_server_thread = threading.Thread(target=run_metrics_server, daemon=True)
daemon=True)
metrics_server_thread.start() metrics_server_thread.start()
time.sleep(1) time.sleep(1)
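launch_metrics_server relies on prometheus_client's multiprocess mode: every worker writes its samples into PROMETHEUS_MULTIPROC_DIR, and the /metrics endpoint merges them at scrape time. The sketch below shows the standard prometheus_client calls for that merge step; it assumes get_filtered_metrics wraps something similar, and the directory path is an example.

import os

from prometheus_client import CollectorRegistry, generate_latest, multiprocess

# The multiprocess directory must exist and be shared by all workers before
# any metric is created (cleanup_prometheus_files above prepares such a dir).
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", "/tmp/fd_prom_example")
os.makedirs(os.environ["PROMETHEUS_MULTIPROC_DIR"], exist_ok=True)

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)  # merges the per-process .db files
payload = generate_latest(registry)           # Prometheus text format, as bytes
print(payload.decode() or "# no samples yet")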
@@ -375,7 +352,8 @@ def control_scheduler(request: ControlSchedulerRequest):
if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config): if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
llm_engine.scheduler.update_config( llm_engine.scheduler.update_config(
load_shards_num=request.load_shards_num, load_shards_num=request.load_shards_num,
reallocate=request.reallocate_shard) reallocate=request.reallocate_shard,
)
else: else:
content.message = "This scheduler doesn't support the `update_config()` method." content.message = "This scheduler doesn't support the `update_config()` method."
content.code = 400 content.code = 400
@@ -388,10 +366,12 @@ def run_controller_server():
""" """
run controller server run controller server
""" """
uvicorn.run(controller_app, uvicorn.run(
controller_app,
host="0.0.0.0", host="0.0.0.0",
port=args.controller_port, port=args.controller_port,
log_level="error") log_level="error",
)
def launch_controller_server(): def launch_controller_server():
@@ -400,12 +380,9 @@ def launch_controller_server():
return return
if not is_port_available(args.host, args.controller_port): if not is_port_available(args.host, args.controller_port):
raise Exception( raise Exception(f"The parameter `controller_port`:{args.controller_port} is already in use.")
f"The parameter `controller_port`:{args.controller_port} is already in use."
)
controller_server_thread = threading.Thread(target=run_controller_server, controller_server_thread = threading.Thread(target=run_controller_server, daemon=True)
daemon=True)
controller_server_thread.start() controller_server_thread.start()
time.sleep(1) time.sleep(1)
View File
@@ -30,6 +30,7 @@ class ErrorResponse(BaseModel):
""" """
Error response from OpenAI API. Error response from OpenAI API.
""" """
object: str = "error" object: str = "error"
message: str message: str
code: int code: int
@@ -39,6 +40,7 @@ class PromptTokenUsageInfo(BaseModel):
""" """
Prompt-related token usage info. Prompt-related token usage info.
""" """
cached_tokens: Optional[int] = None cached_tokens: Optional[int] = None
@@ -46,6 +48,7 @@ class UsageInfo(BaseModel):
""" """
Usage info for a single request. Usage info for a single request.
""" """
prompt_tokens: int = 0 prompt_tokens: int = 0
total_tokens: int = 0 total_tokens: int = 0
completion_tokens: Optional[int] = 0 completion_tokens: Optional[int] = 0
@@ -56,6 +59,7 @@ class FunctionCall(BaseModel):
""" """
Function call. Function call.
""" """
name: str name: str
arguments: str arguments: str
@@ -64,6 +68,7 @@ class ToolCall(BaseModel):
""" """
Tool call. Tool call.
""" """
id: str = None id: str = None
type: Literal["function"] = "function" type: Literal["function"] = "function"
function: FunctionCall function: FunctionCall
@@ -74,6 +79,7 @@ class DeltaFunctionCall(BaseModel):
""" """
Delta function call. Delta function call.
""" """
name: Optional[str] = None name: Optional[str] = None
arguments: Optional[str] = None arguments: Optional[str] = None
@@ -83,6 +89,7 @@ class DeltaToolCall(BaseModel):
""" """
Delta tool call. Delta tool call.
""" """
id: Optional[str] = None id: Optional[str] = None
type: Optional[Literal["function"]] = None type: Optional[Literal["function"]] = None
index: int index: int
@@ -93,6 +100,7 @@ class FunctionDefinition(BaseModel):
""" """
Function definition. Function definition.
""" """
name: str name: str
description: Optional[str] = None description: Optional[str] = None
parameters: Optional[dict[str, Any]] = None parameters: Optional[dict[str, Any]] = None
@@ -102,6 +110,7 @@ class ChatCompletionToolsParam(BaseModel):
""" """
Chat completion tools parameter. Chat completion tools parameter.
""" """
type: Literal["function"] = "function" type: Literal["function"] = "function"
function: FunctionDefinition function: FunctionDefinition
@@ -110,6 +119,7 @@ class ChatMessage(BaseModel):
""" """
Chat message. Chat message.
""" """
role: str role: str
content: str content: str
reasoning_content: Optional[str] = None reasoning_content: Optional[str] = None
@@ -120,6 +130,7 @@ class ChatCompletionResponseChoice(BaseModel):
""" """
Chat completion response choice. Chat completion response choice.
""" """
index: int index: int
message: ChatMessage message: ChatMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
@@ -130,6 +141,7 @@ class ChatCompletionResponse(BaseModel):
""" """
Chat completion response. Chat completion response.
""" """
id: str id: str
object: str = "chat.completion" object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
@@ -137,26 +149,32 @@ class ChatCompletionResponse(BaseModel):
choices: List[ChatCompletionResponseChoice] choices: List[ChatCompletionResponseChoice]
usage: UsageInfo usage: UsageInfo
class LogProbEntry(BaseModel): class LogProbEntry(BaseModel):
""" """
Log probability entry. Log probability entry.
""" """
token: str token: str
logprob: float logprob: float
bytes: Optional[List[int]] = None bytes: Optional[List[int]] = None
top_logprobs: Optional[List["LogProbEntry"]] = None top_logprobs: Optional[List[LogProbEntry]] = None
class LogProbs(BaseModel): class LogProbs(BaseModel):
""" """
LogProbs. LogProbs.
""" """
content: Optional[List[LogProbEntry]] = None content: Optional[List[LogProbEntry]] = None
refusal: Optional[Union[str, None]] = None refusal: Optional[Union[str, None]] = None
class DeltaMessage(BaseModel): class DeltaMessage(BaseModel):
""" """
Delta message for chat completion stream response. Delta message for chat completion stream response.
""" """
role: Optional[str] = None role: Optional[str] = None
content: Optional[str] = None content: Optional[str] = None
token_ids: Optional[List[int]] = None token_ids: Optional[List[int]] = None
@@ -168,6 +186,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
""" """
Chat completion response choice for stream response. Chat completion response choice for stream response.
""" """
index: int index: int
delta: DeltaMessage delta: DeltaMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
@@ -179,6 +198,7 @@ class ChatCompletionStreamResponse(BaseModel):
""" """
Chat completion response for stream response. Chat completion response for stream response.
""" """
id: str id: str
object: str = "chat.completion.chunk" object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
@@ -191,6 +211,7 @@ class CompletionResponseChoice(BaseModel):
""" """
Completion response choice. Completion response choice.
""" """
index: int index: int
text: str text: str
token_ids: Optional[List[int]] = None token_ids: Optional[List[int]] = None
@@ -205,6 +226,7 @@ class CompletionResponse(BaseModel):
""" """
Completion response. Completion response.
""" """
id: str id: str
object: str = "text_completion" object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
@@ -217,6 +239,7 @@ class CompletionResponseStreamChoice(BaseModel):
""" """
Completion response choice for stream response. Completion response choice for stream response.
""" """
index: int index: int
text: str text: str
arrival_time: float = None arrival_time: float = None
@@ -231,6 +254,7 @@ class CompletionStreamResponse(BaseModel):
""" """
Completion response for stream response. Completion response for stream response.
""" """
id: str id: str
object: str = "text_completion" object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
@@ -243,6 +267,7 @@ class StreamOptions(BaseModel):
""" """
Stream options. Stream options.
""" """
include_usage: Optional[bool] = True include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = False continuous_usage_stats: Optional[bool] = False
@@ -251,9 +276,9 @@ class StructuralTag(BaseModel):
""" """
Structural tag. Structural tag.
""" """
begin: str begin: str
structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
alias="schema")
end: str end: str
@@ -261,9 +286,10 @@ class JsonSchemaResponseFormat(BaseModel):
""" """
Json schema for ResponseFormat. Json schema for ResponseFormat.
""" """
name: str name: str
description: Optional[str] = None description: Optional[str] = None
json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
strict: Optional[bool] = None strict: Optional[bool] = None
@@ -271,6 +297,7 @@ class StructuralTagResponseFormat(BaseModel):
""" """
Structural tag for ResponseFormat. Structural tag for ResponseFormat.
""" """
type: Literal["structural_tag"] type: Literal["structural_tag"]
structures: list[StructuralTag] structures: list[StructuralTag]
triggers: list[str] triggers: list[str]
@@ -280,6 +307,7 @@ class ResponseFormat(BaseModel):
""" """
response_format type. response_format type.
""" """
type: Literal["text", "json_object", "json_schema"] type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None json_schema: Optional[JsonSchemaResponseFormat] = None
@@ -291,6 +319,7 @@ class CompletionRequest(BaseModel):
""" """
Completion request to the engine. Completion request to the engine.
""" """
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create # https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = "default" model: Optional[str] = "default"
@@ -333,7 +362,7 @@ class CompletionRequest(BaseModel):
""" """
req_dict = {} req_dict = {}
if request_id is not None: if request_id is not None:
req_dict['request_id'] = request_id req_dict["request_id"] = request_id
for key, value in self.dict().items(): for key, value in self.dict().items():
if value is not None: if value is not None:
req_dict[key] = value req_dict[key] = value
@@ -341,7 +370,7 @@ class CompletionRequest(BaseModel):
for key, value in self.suffix.items(): for key, value in self.suffix.items():
req_dict[key] = value req_dict[key] = value
if prompt is not None: if prompt is not None:
req_dict['prompt'] = prompt req_dict["prompt"] = prompt
if isinstance(prompt[0], int): if isinstance(prompt[0], int):
req_dict["prompt_token_ids"] = prompt req_dict["prompt_token_ids"] = prompt
@@ -363,8 +392,11 @@ class CompletionRequest(BaseModel):
req_dict["guided_json_object"] = guided_json_object req_dict["guided_json_object"] = guided_json_object
guided_schema = [ guided_schema = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar", "guided_json",
"structural_tag" "guided_regex",
"guided_choice",
"guided_grammar",
"structural_tag",
] ]
for key in guided_schema: for key in guided_schema:
item = getattr(self, key, None) item = getattr(self, key, None)
@@ -380,15 +412,16 @@ class CompletionRequest(BaseModel):
Validate stream options Validate stream options
""" """
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError( raise ValueError("Stream options can only be defined when `stream=True`.")
"Stream options can only be defined when `stream=True`.")
guided_count = sum([ guided_count = sum(
[
"guided_json" in data and data["guided_json"] is not None, "guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None, "guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None, "guided_choice" in data and data["guided_choice"] is not None,
"guided_grammar" in data and data["guided_grammar"] is not None "guided_grammar" in data and data["guided_grammar"] is not None,
]) ]
)
if guided_count > 1: if guided_count > 1:
raise ValueError( raise ValueError(
@@ -403,6 +436,7 @@ class ChatCompletionRequest(BaseModel):
""" """
Chat completion request to the engine. Chat completion request to the engine.
""" """
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create # https://platform.openai.com/docs/api-reference/chat/create
messages: Union[List[Any], List[int]] messages: Union[List[Any], List[int]]
@@ -414,8 +448,8 @@ class ChatCompletionRequest(BaseModel):
# remove max_tokens when field is removed from OpenAI API # remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field( max_tokens: Optional[int] = Field(
default=None, default=None,
deprecated= deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
'max_tokens is deprecated in favor of the max_completion_tokens field') )
max_completion_tokens: Optional[int] = None max_completion_tokens: Optional[int] = None
n: Optional[int] = 1 n: Optional[int] = 1
presence_penalty: Optional[float] = None presence_penalty: Optional[float] = None
@@ -451,7 +485,7 @@ class ChatCompletionRequest(BaseModel):
""" """
req_dict = {} req_dict = {}
if request_id is not None: if request_id is not None:
req_dict['request_id'] = request_id req_dict["request_id"] = request_id
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
@@ -483,17 +517,18 @@ class ChatCompletionRequest(BaseModel):
self.guided_json = json_schema self.guided_json = json_schema
elif self.response_format.type == "structural_tag": elif self.response_format.type == "structural_tag":
structural_tag = self.response_format structural_tag = self.response_format
assert structural_tag is not None and isinstance( assert structural_tag is not None and isinstance(structural_tag, StructuralTagResponseFormat)
structural_tag, StructuralTagResponseFormat) self.structural_tag = json.dumps(structural_tag.model_dump(by_alias=True))
self.structural_tag = json.dumps(
structural_tag.model_dump(by_alias=True))
if guided_json_object: if guided_json_object:
req_dict["guided_json_object"] = guided_json_object req_dict["guided_json_object"] = guided_json_object
guided_schema = [ guided_schema = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar", "guided_json",
"structural_tag" "guided_regex",
"guided_choice",
"guided_grammar",
"structural_tag",
] ]
for key in guided_schema: for key in guided_schema:
item = getattr(self, key, None) item = getattr(self, key, None)
@@ -509,16 +544,17 @@ class ChatCompletionRequest(BaseModel):
Validate stream options Validate stream options
""" """
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError( raise ValueError("Stream options can only be defined when `stream=True`.")
"Stream options can only be defined when `stream=True`.")
guided_count = sum([ guided_count = sum(
[
"guided_json" in data and data["guided_json"] is not None, "guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None, "guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None, "guided_choice" in data and data["guided_choice"] is not None,
"guided_grammar" in data and data["guided_grammar"] is not None, "guided_grammar" in data and data["guided_grammar"] is not None,
"structural_tag" in data and data["structural_tag"] is not None "structural_tag" in data and data["structural_tag"] is not None,
]) ]
)
if guided_count > 1: if guided_count > 1:
raise ValueError( raise ValueError(
@@ -537,9 +573,7 @@ class ChatCompletionRequest(BaseModel):
raise ValueError("`top_logprobs` must be a positive value.") raise ValueError("`top_logprobs` must be a positive value.")
if top_logprobs > 0 and not data.get("logprobs"): if top_logprobs > 0 and not data.get("logprobs"):
raise ValueError( raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
"when using `top_logprobs`, `logprobs` must be set to true."
)
return data return data
@@ -548,6 +582,7 @@ class ControlSchedulerRequest(BaseModel):
""" """
Control scheduler request to the engine. Control scheduler request to the engine.
""" """
reset: Optional[bool] = False reset: Optional[bool] = False
load_shards_num: Optional[int] = None load_shards_num: Optional[int] = None
reallocate_shard: Optional[bool] = False reallocate_shard: Optional[bool] = False
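The guided-decoding options on both request models are mutually exclusive, which is what the reformatted guided_count checks above enforce. Below is a self-contained sketch of the same pattern with a pydantic v2 model_validator; the field names mirror the request models, while the class name and error message are illustrative.

from typing import Optional

from pydantic import BaseModel, model_validator


class GuidedOptions(BaseModel):
    # Mirrors the guided-decoding fields on CompletionRequest / ChatCompletionRequest.
    guided_json: Optional[dict] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[list] = None
    guided_grammar: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        # Count how many guided options were supplied; more than one is ambiguous.
        guided_count = sum(
            data.get(key) is not None
            for key in ("guided_json", "guided_regex", "guided_choice", "guided_grammar")
        )
        if guided_count > 1:
            raise ValueError("Only one guided decoding option may be set per request.")
        return data


GuidedOptions(guided_regex=r"\d+")  # accepted
# GuidedOptions(guided_regex=r"\d+", guided_choice=["a"])  # raises ValueError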
View File
@@ -15,21 +15,29 @@
""" """
import asyncio import asyncio
import json
import time import time
import traceback import traceback
import uuid import uuid
from typing import List, Optional from typing import List, Optional
import msgpack
import aiozmq import aiozmq
import msgpack
from aiozmq import zmq from aiozmq import zmq
from fastdeploy.entrypoints.openai.protocol import ( from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionRequest,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponse,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, ChatCompletionResponseChoice,
LogProbEntry, LogProbs, PromptTokenUsageInfo, UsageInfo) ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ErrorResponse,
LogProbEntry,
LogProbs,
PromptTokenUsageInfo,
UsageInfo,
)
from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger, get_host_ip from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.worker.output import LogprobsLists from fastdeploy.worker.output import LogprobsLists
@@ -53,10 +61,7 @@ class OpenAIServingChat:
return True return True
return False return False
async def create_chat_completion( async def create_chat_completion(self, request: ChatCompletionRequest):
self,
request: ChatCompletionRequest
):
""" """
Create a new chat completion using the specified parameters. Create a new chat completion using the specified parameters.
""" """
@@ -81,16 +86,10 @@ class OpenAIServingChat:
del current_req_dict del current_req_dict
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
request, request_id,
request.model,
prompt_token_ids)
else: else:
try: try:
return await self.chat_completion_full_generator( return await self.chat_completion_full_generator(request, request_id, request.model, prompt_token_ids)
request, request_id,
request.model,
prompt_token_ids)
except Exception as e: except Exception as e:
return ErrorResponse(code=400, message=str(e)) return ErrorResponse(code=400, message=str(e))
@@ -106,7 +105,7 @@ class OpenAIServingChat:
request: ChatCompletionRequest, request: ChatCompletionRequest,
request_id: str, request_id: str,
model_name: str, model_name: str,
prompt_token_ids: list() prompt_token_ids: list(),
): ):
""" """
Streaming chat completion generator. Streaming chat completion generator.
@@ -135,14 +134,11 @@ class OpenAIServingChat:
object=chunk_object_type, object=chunk_object_type,
created=created_time, created=created_time,
choices=[], choices=[],
model=model_name model=model_name,
) )
try: try:
dealer = await aiozmq.create_zmq_stream( dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
zmq.DEALER, dealer.write([b"", request_id.encode("utf-8")])
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
dealer.write([b"", request_id.encode('utf-8')])
choices = [] choices = []
current_waiting_time = 0 current_waiting_time = 0
if request.metadata is not None: if request.metadata is not None:
@@ -171,20 +167,29 @@ class OpenAIServingChat:
raise ValueError("{}".format(res["error_msg"])) raise ValueError("{}".format(res["error_msg"]))
self.engine_client.data_processor.process_response_dict( self.engine_client.data_processor.process_response_dict(
res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output) res,
stream=True,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
if res['metrics']['first_token_time'] is not None: if res["metrics"]["first_token_time"] is not None:
arrival_time = res['metrics']['first_token_time'] arrival_time = res["metrics"]["first_token_time"]
inference_start_time = res['metrics']['inference_start_time'] inference_start_time = res["metrics"]["inference_start_time"]
else: else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time arrival_time = res["metrics"]["arrival_time"] - inference_start_time
if first_iteration: if first_iteration:
num_prompt_tokens = len(prompt_token_ids) num_prompt_tokens = len(prompt_token_ids)
num_cached_tokens = res.get("num_cached_tokens", 0) num_cached_tokens = res.get("num_cached_tokens", 0)
for i in range(num_choices): for i in range(num_choices):
choice = ChatCompletionResponseStreamChoice( choice = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None) delta=DeltaMessage(
role="assistant",
content="",
reasoning_content="",
tool_calls=None,
),
) )
if request.metadata is not None and request.metadata.get("training", False): if request.metadata is not None and request.metadata.get("training", False):
choice.delta.token_ids = prompt_token_ids choice.delta.token_ids = prompt_token_ids
@@ -193,14 +198,14 @@ class OpenAIServingChat:
object=chunk_object_type, object=chunk_object_type,
created=created_time, created=created_time,
choices=[choice], choices=[choice],
model=model_name model=model_name,
) )
if include_continuous_usage: if include_continuous_usage:
chunk.usage = UsageInfo( chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=0, completion_tokens=0,
total_tokens=num_prompt_tokens, total_tokens=num_prompt_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens) prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
) )
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
first_iteration = False first_iteration = False
@@ -222,24 +227,32 @@ class OpenAIServingChat:
) )
previous_num_tokens += len(output["token_ids"]) previous_num_tokens += len(output["token_ids"])
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \ delta_message = DeltaMessage(
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", [])) content=delta_text,
reasoning_content=output.get("reasoning_content"),
token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content", []),
)
choice = ChatCompletionResponseStreamChoice( choice = ChatCompletionResponseStreamChoice(
index=0, index=0,
delta=delta_message, delta=delta_message,
logprobs=logprobs_res, logprobs=logprobs_res,
arrival_time=arrival_time arrival_time=arrival_time,
) )
if res["finished"]: if res["finished"]:
num_choices -= 1 num_choices -= 1
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"]) work_process_metrics.e2e_request_latency.observe(
time.time() - res["metrics"]["request_start_time"]
)
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens: if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop" choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \ if (
output.get("finish_reason", "") == "tool_calls": self.engine_client.reasoning_parser == "ernie_x1"
and output.get("finish_reason", "") == "tool_calls"
):
choice.finish_reason = "tool_calls" choice.finish_reason = "tool_calls"
else: else:
choice.finish_reason = "length" choice.finish_reason = "length"
@@ -253,7 +266,7 @@ class OpenAIServingChat:
chunk.usage = UsageInfo( chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=previous_num_tokens, completion_tokens=previous_num_tokens,
total_tokens=num_prompt_tokens + previous_num_tokens total_tokens=num_prompt_tokens + previous_num_tokens,
) )
choices.append(choice) choices.append(choice)
@@ -267,13 +280,12 @@ class OpenAIServingChat:
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = [] choices = []
if include_usage: if include_usage:
completion_tokens = previous_num_tokens completion_tokens = previous_num_tokens
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens total_tokens=num_prompt_tokens + completion_tokens,
) )
chunk = ChatCompletionStreamResponse( chunk = ChatCompletionStreamResponse(
id=request_id, id=request_id,
@@ -281,7 +293,7 @@ class OpenAIServingChat:
created=created_time, created=created_time,
choices=[], choices=[],
model=model_name, model=model_name,
usage=usage usage=usage,
) )
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
@@ -297,7 +309,7 @@ class OpenAIServingChat:
request: ChatCompletionRequest, request: ChatCompletionRequest,
request_id: str, request_id: str,
model_name: str, model_name: str,
prompt_token_ids: list() prompt_token_ids: list(),
): ):
""" """
Full chat completion generator. Full chat completion generator.
@@ -307,11 +319,8 @@ class OpenAIServingChat:
enable_thinking = None enable_thinking = None
include_stop_str_in_output = False include_stop_str_in_output = False
try: try:
dealer = await aiozmq.create_zmq_stream( dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
zmq.DEALER, dealer.write([b"", request_id.encode("utf-8")])
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
dealer.write([b"", request_id.encode('utf-8')])
final_res = None final_res = None
previous_num_tokens = 0 previous_num_tokens = 0
current_waiting_time = 0 current_waiting_time = 0
@@ -340,7 +349,11 @@ class OpenAIServingChat:
enable_thinking = request.metadata.get("enable_thinking") enable_thinking = request.metadata.get("enable_thinking")
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False) include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
data = self.engine_client.data_processor.process_response_dict( data = self.engine_client.data_processor.process_response_dict(
data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output) data,
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
# api_server_logger.debug(f"Client {request_id} received: {data}") # api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"]) previous_num_tokens += len(data["outputs"]["token_ids"])
# The logprob for handling the response # The logprob for handling the response
@@ -375,26 +388,23 @@ class OpenAIServingChat:
content=output["text"], content=output["text"],
reasoning_content=output.get("reasoning_content"), reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call_content"), tool_calls=output.get("tool_call_content"),
token_ids=output.get("token_ids") token_ids=output.get("token_ids"),
) )
logprobs_full_res = None logprobs_full_res = None
if logprob_contents: if logprob_contents:
logprobs_full_res = LogProbs( logprobs_full_res = LogProbs(content=logprob_contents)
content=logprob_contents
)
choice = ChatCompletionResponseChoice( choice = ChatCompletionResponseChoice(
index=0, index=0,
message=message, message=message,
logprobs=logprobs_full_res, logprobs=logprobs_full_res,
finish_reason=None finish_reason=None,
) )
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens: if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop" choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \ if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
output.get("finish_reason", "") == "tool_calls":
choice.finish_reason = "tool_calls" choice.finish_reason = "tool_calls"
else: else:
choice.finish_reason = "length" choice.finish_reason = "length"
@@ -409,7 +419,7 @@ class OpenAIServingChat:
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens, completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0)) prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0)),
) )
work_process_metrics.e2e_request_latency.observe(time.time() - final_res["metrics"]["request_start_time"]) work_process_metrics.e2e_request_latency.observe(time.time() - final_res["metrics"]["request_start_time"])
return ChatCompletionResponse( return ChatCompletionResponse(
@@ -417,7 +427,7 @@ class OpenAIServingChat:
created=created_time, created=created_time,
model=model_name, model=model_name,
choices=choices, choices=choices,
usage=usage usage=usage,
) )
def build_logprobs_response( def build_logprobs_response(
@@ -454,8 +464,9 @@ class OpenAIServingChat:
# Construct the candidate token structure (LogProbEntry) of topk # Construct the candidate token structure (LogProbEntry) of topk
top_logprob_entries: List[LogProbEntry] = [] top_logprob_entries: List[LogProbEntry] = []
for tid, lp in zip(topk_token_ids, topk_logprobs): for tid, lp in zip(topk_token_ids, topk_logprobs):
token_str = self.engine_client.data_processor.process_logprob_response([tid], token_str = self.engine_client.data_processor.process_logprob_response(
clean_up_tokenization_spaces=False) [tid], clean_up_tokenization_spaces=False
)
# token_bytes = token_str.encode("utf-8", errors="replace") # token_bytes = token_str.encode("utf-8", errors="replace")
entry = LogProbEntry( entry = LogProbEntry(
token=token_str, token=token_str,
@@ -468,7 +479,7 @@ class OpenAIServingChat:
token=top_logprob_entries[0].token, token=top_logprob_entries[0].token,
logprob=top_logprob_entries[0].logprob, logprob=top_logprob_entries[0].logprob,
bytes=top_logprob_entries[0].bytes, bytes=top_logprob_entries[0].bytes,
top_logprobs=top_logprob_entries[1:] # Here are the complete topk candidates top_logprobs=top_logprob_entries[1:], # Here are the complete topk candidates
) )
return LogProbs(content=[sampled_entry]) return LogProbs(content=[sampled_entry])
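build_logprobs_response returns a single LogProbs entry for the sampled token and tucks the alternative candidates under its top_logprobs field. The stand-in models below copy the fields of LogProbEntry / LogProbs from protocol.py so the sketch runs on its own; the tokens and log-probabilities are made-up values.

from typing import List, Optional

from pydantic import BaseModel


class LogProbEntry(BaseModel):
    # Same fields as fastdeploy.entrypoints.openai.protocol.LogProbEntry.
    token: str
    logprob: float
    bytes: Optional[List[int]] = None
    top_logprobs: Optional[List["LogProbEntry"]] = None


class LogProbs(BaseModel):
    content: Optional[List[LogProbEntry]] = None


topk = [
    LogProbEntry(token=" Paris", logprob=-0.02),
    LogProbEntry(token=" Lyon", logprob=-4.1),
    LogProbEntry(token=" Nice", logprob=-5.7),
]
sampled_entry = LogProbEntry(
    token=topk[0].token,
    logprob=topk[0].logprob,
    top_logprobs=topk[1:],  # remaining candidates, as in build_logprobs_response
)
print(LogProbs(content=[sampled_entry]).model_dump_json(exclude_none=True))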
View File
@@ -15,33 +15,25 @@
""" """
import asyncio import asyncio
import time
import uuid
from typing import List
import aiozmq import aiozmq
import json
import msgpack import msgpack
from aiozmq import zmq from aiozmq import zmq
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Optional, Union, cast, TypeVar, List
import uuid
from fastapi import Request
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import ( from fastdeploy.entrypoints.openai.protocol import (
ErrorResponse,
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
CompletionStreamResponse,
CompletionResponseStreamChoice,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
UsageInfo, UsageInfo,
DeltaToolCall,
DeltaFunctionCall,
ToolCall,
FunctionCall
) )
from fastdeploy.utils import api_server_logger, get_host_ip from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.engine.request import RequestOutput
class OpenAIServingCompletion: class OpenAIServingCompletion:
@@ -105,9 +97,7 @@ class OpenAIServingCompletion:
current_req_dict = request.to_dict_for_infer(request_id_idx, prompt) current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
try: try:
current_req_dict["arrival_time"] = time.time() current_req_dict["arrival_time"] = time.time()
prompt_batched_token_ids.append( prompt_batched_token_ids.append(self.engine_client.format_and_add_data(current_req_dict))
self.engine_client.format_and_add_data(current_req_dict)
)
except Exception as e: except Exception as e:
return ErrorResponse(message=str(e), code=400) return ErrorResponse(message=str(e), code=400)
@@ -120,7 +110,7 @@ class OpenAIServingCompletion:
request_id=request_id, request_id=request_id,
created_time=created_time, created_time=created_time,
model_name=request.model, model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids prompt_batched_token_ids=prompt_batched_token_ids,
) )
else: else:
try: try:
@@ -130,7 +120,7 @@ class OpenAIServingCompletion:
request_id=request_id, request_id=request_id,
created_time=created_time, created_time=created_time,
model_name=request.model, model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids prompt_batched_token_ids=prompt_batched_token_ids,
) )
except Exception as e: except Exception as e:
return ErrorResponse(code=400, message=str(e)) return ErrorResponse(code=400, message=str(e))
@@ -138,7 +128,6 @@ class OpenAIServingCompletion:
except Exception as e: except Exception as e:
return ErrorResponse(message=str(e), code=400) return ErrorResponse(message=str(e), code=400)
async def completion_full_generator( async def completion_full_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
@@ -146,7 +135,7 @@ class OpenAIServingCompletion:
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
prompt_batched_token_ids: list() prompt_batched_token_ids: list(),
): ):
""" """
Process the full completion request with multiple choices. Process the full completion request with multiple choices.
@@ -155,10 +144,7 @@ class OpenAIServingCompletion:
try: try:
request_ids = [f"{request_id}-{i}" for i in range(num_choices)] request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
# create dealer # create dealer
dealer = await aiozmq.create_zmq_stream( dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
for rid in request_ids: for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")]) dealer.write([b"", rid.encode("utf-8")])
@@ -186,8 +172,7 @@ class OpenAIServingCompletion:
if data.get("error_code", 200) != 200: if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"])) raise ValueError("{}".format(data["error_msg"]))
self.engine_client.data_processor.process_response_dict( self.engine_client.data_processor.process_response_dict(data, stream=False)
data, stream=False)
output_tokens[rid] += len(data["outputs"]["token_ids"]) output_tokens[rid] += len(data["outputs"]["token_ids"])
if data.get("finished", False): if data.get("finished", False):
data["output_token_ids"] = output_tokens[rid] data["output_token_ids"] = output_tokens[rid]
@@ -201,18 +186,15 @@ class OpenAIServingCompletion:
request_id=request_id, request_id=request_id,
created_time=created_time, created_time=created_time,
model_name=model_name, model_name=model_name,
prompt_batched_token_ids=prompt_batched_token_ids prompt_batched_token_ids=prompt_batched_token_ids,
) )
except Exception as e: except Exception as e:
api_server_logger.error( api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
f"Error in completion_full_generator: {e}", exc_info=True
)
raise raise
finally: finally:
if dealer is not None: if dealer is not None:
dealer.close() dealer.close()
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
@@ -220,20 +202,17 @@ class OpenAIServingCompletion:
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
prompt_batched_token_ids: list() prompt_batched_token_ids: list(),
): ):
""" """
Process the stream completion request. Process the stream completion request.
""" """
try: try:
dealer = await aiozmq.create_zmq_stream( dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
for i in range(num_choices): for i in range(num_choices):
req_id = f"{request_id}-{i}" req_id = f"{request_id}-{i}"
dealer.write([b"", req_id.encode('utf-8')]) # send one request per choice dealer.write([b"", req_id.encode("utf-8")]) # send one request per choice
output_tokens = [0] * num_choices output_tokens = [0] * num_choices
inference_start_time = [0] * num_choices inference_start_time = [0] * num_choices
first_iteration = [True] * num_choices first_iteration = [True] * num_choices
@@ -245,7 +224,7 @@ class OpenAIServingCompletion:
id=request_id, id=request_id,
created=created_time, created=created_time,
model=model_name, model=model_name,
choices=choices choices=choices,
) )
current_waiting_time = 0 current_waiting_time = 0
@@ -264,7 +243,6 @@ class OpenAIServingCompletion:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
continue continue
response = msgpack.unpackb(raw_data[-1]) response = msgpack.unpackb(raw_data[-1])
for res in response: for res in response:
idx = int(res["request_id"].split("-")[-1]) idx = int(res["request_id"].split("-")[-1])
@@ -277,39 +255,43 @@ class OpenAIServingCompletion:
id=request_id, id=request_id,
created=created_time, created=created_time,
model=model_name, model=model_name,
choices=[CompletionResponseStreamChoice( choices=[
CompletionResponseStreamChoice(
index=idx, index=idx,
text="", text="",
token_ids=list(prompt_batched_token_ids[idx]) token_ids=list(prompt_batched_token_ids[idx]),
)] )
],
) )
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
first_iteration[idx] = False first_iteration[idx] = False
self.engine_client.data_processor.process_response_dict(res, stream=True)
self.engine_client.data_processor.process_response_dict( if res["metrics"].get("first_token_time") is not None:
res, stream=True) arrival_time = res["metrics"]["first_token_time"]
if res['metrics'].get('first_token_time') is not None: inference_start_time[idx] = res["metrics"]["inference_start_time"]
arrival_time = res['metrics']['first_token_time']
inference_start_time[idx] = res['metrics']['inference_start_time']
else: else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx] arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]
output = res["outputs"] output = res["outputs"]
choices.append(CompletionResponseStreamChoice( choices.append(
CompletionResponseStreamChoice(
index=idx, index=idx,
text=output["text"], text=output["text"],
token_ids=output.get("token_ids"), token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content"), tool_calls=output.get("tool_call_content"),
reasoning_content=output.get("reasoning_content"), reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time arrival_time=arrival_time,
)) )
)
if res["finished"]: if res["finished"]:
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens: if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
chunk.choices[0].finish_reason = "stop" chunk.choices[0].finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \ if (
output.get("finish_reason", "") == "tool_calls": self.engine_client.reasoning_parser == "ernie_x1"
and output.get("finish_reason", "") == "tool_calls"
):
chunk.choices[0].finish_reason = "tool_calls" chunk.choices[0].finish_reason = "tool_calls"
else: else:
chunk.choices[0].finish_reason = "length" chunk.choices[0].finish_reason = "length"
@@ -321,12 +303,11 @@ class OpenAIServingCompletion:
id=request_id, id=request_id,
created=created_time, created=created_time,
model=model_name, model=model_name,
choices=choices choices=choices,
) )
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = [] choices = []
if res["finished"]: if res["finished"]:
num_choices -= 1 num_choices -= 1
if getattr(request, "stream_options", None) and request.stream_options.include_usage: if getattr(request, "stream_options", None) and request.stream_options.include_usage:
@@ -337,8 +318,8 @@ class OpenAIServingCompletion:
choices=[], choices=[],
usage=UsageInfo( usage=UsageInfo(
prompt_tokens=len(prompt_batched_token_ids[idx]), prompt_tokens=len(prompt_batched_token_ids[idx]),
completion_tokens=output_tokens[idx] completion_tokens=output_tokens[idx],
) ),
) )
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
if choices: if choices:
@@ -346,7 +327,6 @@ class OpenAIServingCompletion:
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = [] choices = []
except Exception as e: except Exception as e:
yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n" yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
finally: finally:
@@ -355,7 +335,6 @@ class OpenAIServingCompletion:
dealer.close() dealer.close()
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
def request_output_to_completion_response( def request_output_to_completion_response(
self, self,
final_res_batch: List[RequestOutput], final_res_batch: List[RequestOutput],
@@ -363,7 +342,7 @@ class OpenAIServingCompletion:
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
prompt_batched_token_ids: list() prompt_batched_token_ids: list(),
) -> CompletionResponse: ) -> CompletionResponse:
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
@@ -389,12 +368,13 @@ class OpenAIServingCompletion:
output_text = output["text"] output_text = output["text"]
choice_data = CompletionResponseChoice( choice_data = CompletionResponseChoice(
token_ids=token_ids,
index=len(choices), index=len(choices),
text=output_text, text=output_text,
reasoning_content=output.get('reasoning_content'), reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call_content"), tool_calls=output.get("tool_call_content"),
logprobs=None, logprobs=None,
finish_reason=None finish_reason=None,
) )
choices.append(choice_data) choices.append(choice_data)
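Both streaming generators above emit Server-Sent-Events-style frames of the form data: {json}\n\n and end the stream with data: [DONE]\n\n. The sketch below parses a simulated stream of such frames; the payloads are abbreviated stand-ins rather than full CompletionStreamResponse objects.

import json

# Simulated frames as yielded by completion_stream_generator (abbreviated).
raw_events = [
    'data: {"id": "cmpl-1", "choices": [{"index": 0, "text": "Hel"}]}\n\n',
    'data: {"id": "cmpl-1", "choices": [{"index": 0, "text": "lo"}]}\n\n',
    "data: [DONE]\n\n",
]

for event in raw_events:
    payload = event[len("data: "):].strip()
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["text"], end="")
print()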
View File
@@ -42,7 +42,7 @@ response = client.completions.create(
) )
for chunk in response: for chunk in response:
print(chunk.choices[0].text, end='') print(chunk.choices[0].text, end="")
print("\n") print("\n")
# Chat completion # Chat completion
@@ -76,5 +76,5 @@ response = client.chat.completions.create(
for chunk in response: for chunk in response:
if chunk.choices[0].delta is not None: if chunk.choices[0].delta is not None:
print(chunk.choices[0].delta, end='') print(chunk.choices[0].delta, end="")
print("\n") print("\n")

View File

@@ -20,115 +20,62 @@ from typing import Any, Callable
environment_variables: dict[str, Callable[[], Any]] = { environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use BF16 on CPU. # Whether to use BF16 on CPU.
"FD_CPU_USE_BF16": "FD_CPU_USE_BF16": lambda: os.getenv("FD_CPU_USE_BF16", "False"),
lambda: os.getenv("FD_CPU_USE_BF16", "False"),
# CUDA architecture to build FastDeploy. This is a list of strings # CUDA architecture to build FastDeploy. This is a list of strings
# such as [80,90]. # such as [80,90].
"FD_BUILDING_ARCS": "FD_BUILDING_ARCS": lambda: os.getenv("FD_BUILDING_ARCS", "[]"),
lambda: os.getenv("FD_BUILDING_ARCS", "[]"),
# Log directory. # Log directory.
"FD_LOG_DIR": "FD_LOG_DIR": lambda: os.getenv("FD_LOG_DIR", "log"),
lambda: os.getenv("FD_LOG_DIR", "log"),
# Whether to use debug mode; can be set to 0 or 1 # Whether to use debug mode; can be set to 0 or 1
"FD_DEBUG": "FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
lambda: os.getenv("FD_DEBUG", "0"),
# Number of days to keep fastdeploy logs. # Number of days to keep fastdeploy logs.
"FD_LOG_BACKUP_COUNT": "FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
# Model download cache directory. # Model download cache directory.
"FD_MODEL_CACHE": "FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
lambda: os.getenv("FD_MODEL_CACHE", None),
# Maximum number of stop sequences. # Maximum number of stop sequences.
"FD_MAX_STOP_SEQS_NUM": "FD_MAX_STOP_SEQS_NUM": lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
# Maximum length of stop sequences. # Maximum length of stop sequences.
"FD_STOP_SEQS_MAX_LEN": "FD_STOP_SEQS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
# GPU devices that will be used. This is a string that is # GPU devices that will be used. This is a string that is
# split by comma, such as 0,1,2. # split by comma, such as 0,1,2.
"CUDA_VISIBLE_DEVICES": "CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
# Whether to use HuggingFace tokenizer. # Whether to use HuggingFace tokenizer.
"FD_USE_HF_TOKENIZER": "FD_USE_HF_TOKENIZER": lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
# Set the high watermark (HWM) for receiving data during ZMQ initialization # Set the high watermark (HWM) for receiving data during ZMQ initialization
"FD_ZMQ_SNDHWM": "FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 10000),
lambda: os.getenv("FD_ZMQ_SNDHWM", 10000),
# cache kv quant params directory # cache kv quant params directory
"FD_CACHE_PARAMS": "FD_CACHE_PARAMS": lambda: os.getenv("FD_CACHE_PARAMS", "none"),
lambda: os.getenv("FD_CACHE_PARAMS", "none"),
# Set attention backend. "NATIVE_ATTN", "APPEND_ATTN" # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN"
# and "MLA_ATTN" can be set currently. # and "MLA_ATTN" can be set currently.
"FD_ATTENTION_BACKEND": "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
# Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently. # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
"FD_SAMPLING_CLASS": "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
# Set MoE backend. "cutlass", "marlin" and "triton" can be set currently. # Set MoE backend. "cutlass", "marlin" and "triton" can be set currently.
"FD_MOE_BACKEND": "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Set whether to disable recomputing the request when the KV cache is full. # Set whether to disable recomputing the request when the KV cache is full.
"FD_DISABLED_RECOVER": "FD_DISABLED_RECOVER": lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
# Set triton kernel JIT compilation directory. # Set triton kernel JIT compilation directory.
"FD_TRITON_KERNEL_CACHE_DIR": "FD_TRITON_KERNEL_CACHE_DIR": lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
# Whether to transition from standalone PD decoupling to centralized inference # Whether to transition from standalone PD decoupling to centralized inference
"FD_PD_CHANGEABLE": "FD_PD_CHANGEABLE": lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
# Whether to use fastsafetensor load weight (0 or 1) # Whether to use fastsafetensor load weight (0 or 1)
"FD_USE_FASTSAFETENSOR": "FD_USE_FASTSAFETENSOR": lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
# Whether to use DeepGemm for FP8 blockwise MoE. # Whether to use DeepGemm for FP8 blockwise MoE.
"FD_USE_DEEP_GEMM": "FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
# Whether to use aggregate send. # Whether to use aggregate send.
"FD_USE_AGGREGATE_SEND": "FD_USE_AGGREGATE_SEND": lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
# Whether to enable tracing. # Whether to enable tracing.
"TRACES_ENABLE": "TRACES_ENABLE": lambda: os.getenv("TRACES_ENABLE", "false"),
lambda: os.getenv("TRACES_ENABLE", "false"),
# Set trace server name. # Set trace server name.
"FD_SERVICE_NAME": "FD_SERVICE_NAME": lambda: os.getenv("FD_SERVICE_NAME", "FastDeploy"),
lambda: os.getenv("FD_SERVICE_NAME", "FastDeploy"),
# Set trace host name. # Set trace host name.
"FD_HOST_NAME": "FD_HOST_NAME": lambda: os.getenv("FD_HOST_NAME", "localhost"),
lambda: os.getenv("FD_HOST_NAME", "localhost"),
# Set trace exporter. # Set trace exporter.
"TRACES_EXPORTER": "TRACES_EXPORTER": lambda: os.getenv("TRACES_EXPORTER", "console"),
lambda: os.getenv("TRACES_EXPORTER", "console"),
# Set trace exporter_otlp_endpoint. # Set trace exporter_otlp_endpoint.
"EXPORTER_OTLP_ENDPOINT": "EXPORTER_OTLP_ENDPOINT": lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"),
lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"),
# Set trace exporter_otlp_headers. # Set trace exporter_otlp_headers.
"EXPORTER_OTLP_HEADERS": "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
} }
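
A minimal sketch of the lambda-based environment table above; the trimmed-down table and the get_env accessor are assumptions for illustration only:

import os
from typing import Any, Callable

# Each value is a zero-argument callable, so the variable is read lazily at
# lookup time rather than once at import time.
environment_variables: dict[str, Callable[[], Any]] = {
    "FD_LOG_DIR": lambda: os.getenv("FD_LOG_DIR", "log"),
    "FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
    "FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}


def get_env(name: str) -> Any:
    # Assumed accessor: resolve a variable by calling its lambda.
    return environment_variables[name]()


print(get_env("FD_LOG_DIR"))  # "log" unless FD_LOG_DIR is set in the environment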

View File

@@ -43,8 +43,7 @@ def import_custom_ops(package, module_name, global_ns):
logger.warning(f"Failed to import op {func_name}: {e}") logger.warning(f"Failed to import op {func_name}: {e}")
except Exception: except Exception:
logger.warning( logger.warning(f"Ops of {package} import failed, it may not be compiled.")
f"Ops of {package} import failed, it may not be compiled.")
preprocess_static_op(global_ns) preprocess_static_op(global_ns)
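
A hedged sketch of the defensive custom-op import pattern above; the module and op names are made up, and the real helper also runs static-op preprocessing afterwards:

import importlib
import logging

logger = logging.getLogger(__name__)


def import_custom_ops(module_name, op_names, global_ns):
    # Pull each op into the caller's namespace individually so that one
    # missing symbol only logs a warning instead of aborting startup.
    try:
        module = importlib.import_module(module_name)
        for func_name in op_names:
            try:
                global_ns[func_name] = getattr(module, func_name)
            except AttributeError as e:
                logger.warning(f"Failed to import op {func_name}: {e}")
    except Exception:
        logger.warning(f"Ops of {module_name} import failed, it may not be compiled.")


import_custom_ops("math", ["sqrt", "no_such_op"], globals())  # warns about no_such_op only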

View File

@@ -69,12 +69,12 @@ class ErnieProcessor(BaseDataProcessor):
# Generation config # Generation config
try: try:
self.generation_config = GenerationConfig.from_pretrained( self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
self.model_name_or_path)
except Exception as e: except Exception as e:
data_processor_logger.warning( data_processor_logger.warning(
f"Can't find generation config, so it will not use " f"Can't find generation config, so it will not use "
f"generation_config field in the model config, details={e}") f"generation_config field in the model config, details={e}"
)
self.generation_config = None self.generation_config = None
def process_request(self, request, max_model_len=None, **kwargs): def process_request(self, request, max_model_len=None, **kwargs):
@@ -89,8 +89,7 @@ class ErnieProcessor(BaseDataProcessor):
str: error message str: error message
""" """
request = self._apply_default_parameters(request) request = self._apply_default_parameters(request)
if request.get("eos_token_ids") is None or len( if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids request.eos_token_ids = self.eos_token_ids
stop_sequences = request.get("stop", []) stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0: if stop_sequences is not None and len(stop_sequences) != 0:
@@ -98,12 +97,9 @@ class ErnieProcessor(BaseDataProcessor):
request.set("stop_token_ids", stop_seqs) request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len) request.set("stop_seqs_len", stop_seqs_len)
if request.prompt_token_ids is None or len( if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
request.prompt_token_ids) == 0:
system = request.get("system")
if request.prompt is None and request.messages is None: if request.prompt is None and request.messages is None:
raise ValueError( raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
f"The request should have `input_ids`, `text` or `messages`: {request}.")
if request.prompt is not None or not request.raw_request: if request.prompt is not None or not request.raw_request:
prompt = request.prompt if request.prompt is not None else request.messages[0] prompt = request.prompt if request.prompt is not None else request.messages[0]
prompt = prompt[0] if isinstance(prompt, list) else prompt prompt = prompt[0] if isinstance(prompt, list) else prompt
@@ -114,14 +110,13 @@ class ErnieProcessor(BaseDataProcessor):
else: else:
request.prompt_token_ids = self.messages2ids(request.to_dict()) request.prompt_token_ids = self.messages2ids(request.to_dict())
if max_model_len is not None and len( if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids) > max_model_len: request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
request.prompt_token_ids = request.prompt_token_ids[:
max_model_len -
1]
if request.get("max_tokens") is None: if request.get("max_tokens") is None:
request.set("max_tokens", request.set(
max(1, max_model_len - len(request.prompt_token_ids))) "max_tokens",
max(1, max_model_len - len(request.prompt_token_ids)),
)
if request.get("temperature") < _SAMPLING_EPS: if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling # zero temperature is equivalent to greedy sampling
request.set("temperature", 1) request.set("temperature", 1)
@@ -140,45 +135,36 @@ class ErnieProcessor(BaseDataProcessor):
str: error message str: error message
""" """
request = self._apply_default_parameters(request) request = self._apply_default_parameters(request)
if not request.get('eos_token_ids'): if not request.get("eos_token_ids"):
request['eos_token_ids'] = self.eos_token_ids request["eos_token_ids"] = self.eos_token_ids
# Handle stop_sequences # Handle stop_sequences
stop_sequences = request.get('stop', []) stop_sequences = request.get("stop", [])
if stop_sequences: if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences) stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request['stop_token_ids'] = stop_seqs request["stop_token_ids"] = stop_seqs
request['stop_seqs_len'] = stop_seqs_len request["stop_seqs_len"] = stop_seqs_len
system = request.get("system")
# Handle prompt_token_ids # Handle prompt_token_ids
if not request.get('prompt_token_ids'): if not request.get("prompt_token_ids"):
if request.get('prompt') is None and request.get( if request.get("prompt") is None and request.get("messages") is None:
'messages') is None: raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
raise ValueError( if request.get("prompt"):
f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}" prompt = request.get("prompt")
)
if request.get('prompt'):
prompt = request.get('prompt')
prompt = prompt[0] if isinstance(prompt, list) else prompt prompt = prompt[0] if isinstance(prompt, list) else prompt
tokens = self.tokenizer.tokenize(prompt) tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens) token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request['prompt_token_ids'] = token_ids request["prompt_token_ids"] = token_ids
req_id = request.get("request_id", None) req_id = request.get("request_id", None)
data_processor_logger.info( data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}"
)
else: else:
request['prompt_token_ids'] = self.messages2ids(request) request["prompt_token_ids"] = self.messages2ids(request)
# Truncate prompts that exceed the length limit # Truncate prompts that exceed the length limit
if max_model_len is not None and len( if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request['prompt_token_ids']) > max_model_len: request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
request['prompt_token_ids'] = request[
'prompt_token_ids'][:max_model_len - 1]
if request.get("max_tokens") is None: if request.get("max_tokens") is None:
request["max_tokens"] = max( request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
1, max_model_len - len(request['prompt_token_ids']))
if request.get("temperature") < _SAMPLING_EPS: if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling # zero temperature is equivalent to greedy sampling
request["temperature"] = 1 request["temperature"] = 1
@@ -200,22 +186,18 @@ class ErnieProcessor(BaseDataProcessor):
req_id = response_dict.request_id req_id = response_dict.request_id
token_ids = response_dict.outputs.token_ids token_ids = response_dict.outputs.token_ids
response_dict.usage = { response_dict.usage = {"completion_tokens": response_dict.outputs.index + 1}
"completion_tokens": response_dict.outputs.index + 1
}
if token_ids[-1] == self.tokenizer.eos_token_id: if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1] token_ids = token_ids[:-1]
full_text = self.tokenizer.decode(token_ids) full_text = self.tokenizer.decode(token_ids)
if self.reasoning_parser: if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content( reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
full_text, response_dict)
response_dict.outputs.text = text response_dict.outputs.text = text
response_dict.outputs.reasoning_content = reasoning_content response_dict.outputs.reasoning_content = reasoning_content
else: else:
response_dict.outputs.text = full_text response_dict.outputs.text = full_text
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}") data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
if response_dict.outputs.text == "" and \ if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
response_dict.outputs.reasoning_content == "":
return None return None
return response_dict return response_dict
@@ -230,8 +212,7 @@ class ErnieProcessor(BaseDataProcessor):
Dict: response containing text fields Dict: response containing text fields
""" """
if stream: if stream:
return self.process_response_dict_streaming( return self.process_response_dict_streaming(response_dict, **kwargs)
response_dict, **kwargs)
else: else:
return self.process_response_dict_normal(response_dict, **kwargs) return self.process_response_dict_normal(response_dict, **kwargs)
@@ -255,16 +236,12 @@ class ErnieProcessor(BaseDataProcessor):
if is_end: if is_end:
full_text = previous_texts + delta_text full_text = previous_texts + delta_text
if self.reasoning_parser: if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content( reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
full_text, response_dict)
response_dict["outputs"]["text"] = text response_dict["outputs"]["text"] = text
response_dict["outputs"][ response_dict["outputs"]["reasoning_content"] = reasoning_content
"reasoning_content"] = reasoning_content
else: else:
response_dict["outputs"]["text"] = full_text response_dict["outputs"]["text"] = full_text
data_processor_logger.info( data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}"
)
del self.decode_status[req_id] del self.decode_status[req_id]
return response_dict return response_dict
@@ -286,20 +263,22 @@ class ErnieProcessor(BaseDataProcessor):
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"): if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id: if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1] token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens( delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
token_ids, req_id)
if enable_thinking and self.reasoning_parser: if enable_thinking and self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming( reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts, previous_texts + delta_text, delta_text, previous_texts,
previous_token_ids, previous_token_ids + token_ids, token_ids) previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
)
response_dict["outputs"]["text"] = text response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content response_dict["outputs"]["reasoning_content"] = reasoning_content
else: else:
response_dict["outputs"]["text"] = delta_text response_dict["outputs"]["text"] = delta_text
if is_end: if is_end:
data_processor_logger.info( data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}"
)
del self.decode_status[req_id] del self.decode_status[req_id]
return response_dict return response_dict
@@ -320,15 +299,15 @@ class ErnieProcessor(BaseDataProcessor):
request_or_messages, request_or_messages,
tokenize=False, tokenize=False,
split_special_tokens=False, split_special_tokens=False,
add_special_tokens=False) add_special_tokens=False,
)
req_id = None req_id = None
if isinstance(request_or_messages, dict): if isinstance(request_or_messages, dict):
req_id = request_or_messages.get("request_id", None) req_id = request_or_messages.get("request_id", None)
tokens = self.tokenizer.tokenize(spliced_message) tokens = self.tokenizer.tokenize(spliced_message)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens) token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info( data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
return token_ids return token_ids
def ids2tokens(self, token_id, task_id): def ids2tokens(self, token_id, task_id):
@@ -352,7 +331,8 @@ class ErnieProcessor(BaseDataProcessor):
previous_token_ids = self.decode_status[task_id][2] previous_token_ids = self.decode_status[task_id][2]
previous_texts = self.decode_status[task_id][3] previous_texts = self.decode_status[task_id][3]
decode_str, prefix_offset, read_offset = self.tokenizer.decode_token( decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(
previous_token_ids + token_id, prefix_offset, read_offset) previous_token_ids + token_id, prefix_offset, read_offset
)
self.decode_status[task_id][0] = prefix_offset self.decode_status[task_id][0] = prefix_offset
self.decode_status[task_id][1] = read_offset self.decode_status[task_id][1] = read_offset
self.decode_status[task_id][2] += token_id self.decode_status[task_id][2] += token_id
@@ -368,17 +348,15 @@ class ErnieProcessor(BaseDataProcessor):
tokenizer (AutoTokenizer) tokenizer (AutoTokenizer)
""" """
vocab_file_names = [ vocab_file_names = [
"tokenizer.model", "spm.model", "ernie_token_100k.model" "tokenizer.model",
"spm.model",
"ernie_token_100k.model",
] ]
for i in range(len(vocab_file_names)): for i in range(len(vocab_file_names)):
if os.path.exists( if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_names[i])):
os.path.join(self.model_name_or_path, ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
vocab_file_names[i])):
ErnieBotTokenizer.resource_files_names[
"vocab_file"] = vocab_file_names[i]
break break
self.tokenizer = ErnieBotTokenizer.from_pretrained( self.tokenizer = ErnieBotTokenizer.from_pretrained(self.model_name_or_path)
self.model_name_or_path)
def get_pad_id(self): def get_pad_id(self):
""" """
@@ -391,16 +369,17 @@ class ErnieProcessor(BaseDataProcessor):
# return self.tokenizer.eos_token # return self.tokenizer.eos_token
return self.tokenizer.pad_token_id return self.tokenizer.pad_token_id
def pad_batch_data(self, def pad_batch_data(
self,
insts, insts,
pad_id=0, pad_id=0,
return_seq_len=False, return_seq_len=False,
return_array=True, return_array=True,
pad_style="right"): pad_style="right",
):
"""Pad the instances to the max sequence length in batch.""" """Pad the instances to the max sequence length in batch."""
if len(insts) == 0: if len(insts) == 0:
padded_insts = np.array([[]], padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
dtype=np.int64) if return_array else [[]]
if return_seq_len: if return_seq_len:
seq_len = np.array([], dtype=np.int64) if return_array else [] seq_len = np.array([], dtype=np.int64) if return_array else []
return padded_insts, seq_len return padded_insts, seq_len
@@ -408,15 +387,11 @@ class ErnieProcessor(BaseDataProcessor):
max_len = max(map(len, insts)) max_len = max(map(len, insts))
if pad_style == "left": if pad_style == "left":
padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
for inst in insts]
else: else:
padded_insts = [ padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts
]
if return_array: if return_array:
padded_insts = np.array(padded_insts, padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])
dtype=np.int64).reshape([-1, max_len])
if return_seq_len: if return_seq_len:
seq_len = [len(inst) for inst in insts] seq_len = [len(inst) for inst in insts]
@@ -432,15 +407,9 @@ class ErnieProcessor(BaseDataProcessor):
stop_seqs = [] stop_seqs = []
for seq in stop_sequences: for seq in stop_sequences:
if seq != self.tokenizer.eos_token_id: if seq != self.tokenizer.eos_token_id:
stop_seqs.append( stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
self.tokenizer.convert_tokens_to_ids( stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
self.tokenizer.tokenize(seq))) data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs,
pad_id=-1,
return_seq_len=True,
return_array=False)
data_processor_logger.debug(
f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
return stop_seqs, stop_seqs_len return stop_seqs, stop_seqs_len
def process_logprob_response(self, token_ids, **kwargs): def process_logprob_response(self, token_ids, **kwargs):
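
The stop-sequence handling above tokenizes each stop string and right-pads the batch with -1; a small standalone sketch of that padding behaviour (pad_right is an assumed name, not the processor's API):

def pad_right(insts, pad_id=-1):
    # Pad every sequence to the longest one in the batch and also return the
    # original lengths, mirroring pad_batch_data(..., return_seq_len=True).
    max_len = max(map(len, insts))
    padded = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
    return padded, [len(inst) for inst in insts]


stop_seqs, stop_seqs_len = pad_right([[101, 102, 103], [104]])
print(stop_seqs)      # [[101, 102, 103], [104, -1, -1]]
print(stop_seqs_len)  # [3, 1]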

View File

@@ -19,19 +19,14 @@
import os import os
import re import re
from shutil import copyfile from shutil import copyfile
from typing import Dict, Optional, Tuple, List from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import sentencepiece as spm
import paddle import paddle
import sentencepiece as spm
from paddleformers.utils.log import logger
from paddleformers.transformers import PretrainedTokenizer from paddleformers.transformers import PretrainedTokenizer
from paddleformers.transformers.tokenizer_utils_base import ( from paddleformers.transformers.tokenizer_utils_base import PaddingStrategy, TextInput
PaddingStrategy, from paddleformers.utils.log import logger
TextInput,
)
class ErnieBotTokenizer(PretrainedTokenizer): class ErnieBotTokenizer(PretrainedTokenizer):
@@ -47,7 +42,12 @@ class ErnieBotTokenizer(PretrainedTokenizer):
pretrained_init_configuration = { pretrained_init_configuration = {
"ernie-bot-10b": {}, "ernie-bot-10b": {},
} }
model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"] model_input_names = [
"input_ids",
"position_ids",
"attention_mask",
"labels",
]
padding_side = "right" padding_side = "right"
def __init__( def __init__(
@@ -222,9 +222,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
# TODO: should this be in the base class? # TODO: should this be in the base class?
if hasattr(self, "do_lower_case") and self.do_lower_case: if hasattr(self, "do_lower_case") and self.do_lower_case:
# convert non-special tokens to lowercase # convert non-special tokens to lowercase
escaped_special_toks = [ escaped_special_toks = [re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)]
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)
]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
@@ -303,7 +301,12 @@ class ErnieBotTokenizer(PretrainedTokenizer):
elif not isinstance(attention_mask, np.ndarray): elif not isinstance(attention_mask, np.ndarray):
raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ") raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ")
else: else:
attention_mask = np.tril(np.ones((len(required_input), len(required_input)), dtype=np.int64)) attention_mask = np.tril(
np.ones(
(len(required_input), len(required_input)),
dtype=np.int64,
)
)
attention_mask = np.expand_dims(attention_mask, axis=0) attention_mask = np.expand_dims(attention_mask, axis=0)
if needs_to_be_padded: if needs_to_be_padded:
difference = max_length - len(required_input) difference = max_length - len(required_input)
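
The padding branch above builds a lower-triangular (causal) attention mask with a leading batch axis; a minimal numpy sketch of the same construction, with the sequence length assumed for illustration:

import numpy as np

seq_len = 4  # assumed stand-in for len(required_input)
attention_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int64))
attention_mask = np.expand_dims(attention_mask, axis=0)  # shape (1, 4, 4)
print(attention_mask[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]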

View File

@@ -17,18 +17,23 @@
import os import os
import numpy as np import numpy as np
import re
from fastdeploy.input.mm_processor import DataProcessor, IDS_TYPE_FLAG
from fastdeploy.input.ernie_processor import ErnieProcessor
from fastdeploy.engine.request import Request from fastdeploy.engine.request import Request
from fastdeploy.entrypoints.chat_utils import parse_chat_messages from fastdeploy.input.ernie_processor import ErnieProcessor
from fastdeploy.input.mm_processor import IDS_TYPE_FLAG, DataProcessor
from fastdeploy.utils import data_processor_logger from fastdeploy.utils import data_processor_logger
class ErnieMoEVLProcessor(ErnieProcessor): class ErnieMoEVLProcessor(ErnieProcessor):
"""The processor class for ERNIE MoE VL models.""" """The processor class for ERNIE MoE VL models."""
def __init__(self, model_name_or_path, limit_mm_per_prompt=None, mm_processor_kwargs=None,
reasoning_parser_obj=None): def __init__(
self,
model_name_or_path,
limit_mm_per_prompt=None,
mm_processor_kwargs=None,
reasoning_parser_obj=None,
):
self.use_hf_tokenizer = False self.use_hf_tokenizer = False
if "merge_llm_model" in model_name_or_path: if "merge_llm_model" in model_name_or_path:
@@ -41,7 +46,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.ernie_processor = DataProcessor( self.ernie_processor = DataProcessor(
tokenizer_name=tokenizer_path, tokenizer_name=tokenizer_path,
image_preprocessor_name=preprocessor_path, image_preprocessor_name=preprocessor_path,
**processor_kwargs **processor_kwargs,
) )
self.ernie_processor.eval() self.ernie_processor.eval()
self.image_patch_id = self.ernie_processor.image_patch_id self.image_patch_id = self.ernie_processor.image_patch_id
@@ -73,7 +78,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
def process_request(self, request, max_model_len=None, **kwargs): def process_request(self, request, max_model_len=None, **kwargs):
"""process the input data""" """process the input data"""
task = request.to_dict() task = request.to_dict()
task['enable_thinking'] = kwargs.get("enable_thinking", True) task["enable_thinking"] = kwargs.get("enable_thinking", True)
self.process_request_dict(task, max_model_len) self.process_request_dict(task, max_model_len)
request = Request.from_dict(task) request = Request.from_dict(task)
@@ -101,13 +106,14 @@ class ErnieMoEVLProcessor(ErnieProcessor):
"video_frames_sample": str, "video_frames_sample": str,
"video_max_frames": int, "video_max_frames": int,
"video_min_frames": int, "video_min_frames": int,
"video_fps": int "video_fps": int,
} }
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in expected_types and not isinstance(value, expected_types[key]): if key in expected_types and not isinstance(value, expected_types[key]):
raise ValueError( raise ValueError(
f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}") f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}"
)
return kwargs return kwargs
@@ -117,11 +123,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
def _parse_limits(self, limits): def _parse_limits(self, limits):
"""解析多模态限制配置""" """解析多模态限制配置"""
DEFAULT_LIMITS = { DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}
"image": 1,
"video": 1,
"audio": 1
}
if not limits: if not limits:
return DEFAULT_LIMITS return DEFAULT_LIMITS
@@ -141,10 +143,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
mm_data = item mm_data = item
else: else:
# The request contains messages # The request contains messages
mm_data = { mm_data = {"image": [], "video": []}
"image": [],
"video": []
}
for message in item: for message in item:
if isinstance(message.get("content"), list): if isinstance(message.get("content"), list):
@@ -158,10 +157,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
if modality in self.limit_mm_per_prompt: if modality in self.limit_mm_per_prompt:
limit = self.limit_mm_per_prompt[modality] limit = self.limit_mm_per_prompt[modality]
if len(data) > limit: if len(data) > limit:
raise ValueError( raise ValueError(f"Too many {modality} items in prompt, " f"got {len(data)} but limit is {limit}")
f"Too many {modality} items in prompt, "
f"got {len(data)} but limit is {limit}"
)
def process_request_dict(self, request, max_model_len=None): def process_request_dict(self, request, max_model_len=None):
"""process the input data""" """process the input data"""
@@ -200,13 +196,10 @@ class ErnieMoEVLProcessor(ErnieProcessor):
request["multimodal_inputs"] = outputs request["multimodal_inputs"] = outputs
# Truncate prompts that exceed the length limit # Truncate prompts that exceed the length limit
if max_model_len is not None and len( if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request['prompt_token_ids']) > max_model_len: request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
request['prompt_token_ids'] = request[
'prompt_token_ids'][:max_model_len - 1]
if request.get("max_tokens") is None: if request.get("max_tokens") is None:
request["max_tokens"] = max( request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
1, max_model_len - len(request['prompt_token_ids']))
data_processor_logger.info(f"Processed request {request}") data_processor_logger.info(f"Processed request {request}")
return request return request
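
A standalone sketch of the per-modality limit check the VL processor applies above; the defaults mirror _parse_limits, while check_mm_limits is an assumed free-function stand-in for the method:

DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}


def check_mm_limits(mm_data, limit_mm_per_prompt=DEFAULT_LIMITS):
    # Reject requests that carry more items of a modality than the limit allows.
    for modality, data in mm_data.items():
        if modality in limit_mm_per_prompt:
            limit = limit_mm_per_prompt[modality]
            if len(data) > limit:
                raise ValueError(f"Too many {modality} items in prompt, got {len(data)} but limit is {limit}")


check_mm_limits({"image": ["img-0"], "video": []})  # passes; a second image would raise ValueError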

View File

@@ -14,10 +14,10 @@
# limitations under the License. # limitations under the License.
""" """
from .process import DataProcessor, fancy_print, IDS_TYPE_FLAG from .process import IDS_TYPE_FLAG, DataProcessor, fancy_print
__all__ = [ __all__ = [
'DataProcessor', "DataProcessor",
'fancy_print', "fancy_print",
'IDS_TYPE_FLAG', "IDS_TYPE_FLAG",
] ]

View File

@@ -17,4 +17,4 @@
from .get_image_preprocessor import get_image_preprocessor from .get_image_preprocessor import get_image_preprocessor
from .image_preprocessor_adaptive import AdaptiveImageProcessor from .image_preprocessor_adaptive import AdaptiveImageProcessor
__all__ = ['get_image_preprocessor', 'AdaptiveImageProcessor'] __all__ = ["get_image_preprocessor", "AdaptiveImageProcessor"]

View File

@@ -16,9 +16,10 @@
"""get image preprocessor""" """get image preprocessor"""
from .image_preprocessor_adaptive import AdaptiveImageProcessor
from fastdeploy.utils import data_processor_logger from fastdeploy.utils import data_processor_logger
from .image_preprocessor_adaptive import AdaptiveImageProcessor
def get_image_preprocessor(args): def get_image_preprocessor(args):
""" """

Some files were not shown because too many files have changed in this diff.