polish code with new pre-commit rule (#2923)

Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

.flake8 (new file, 7 additions)

@@ -0,0 +1,7 @@
[flake8]
ignore = E203, E402, E501, E731, E741, W503, W605, E722
max-line-length = 119
# E402: module level import not at top of file
per-file-ignores =
__init__.py:F401,F403,E402
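For context, the ignore list above tolerates a few common patterns without per-line noqa comments; a small standalone illustration (the snippet below is an assumed example, not a file from this repository):

```python
# Illustration only (not from this repo): constructs that this .flake8
# configuration deliberately does not flag.
l = 10                    # E741 (ambiguous variable name) is ignored
square = lambda x: x * x  # E731 (assigning a lambda) is ignored

print(square(l))          # prints 100
```

The per-file-ignores entry additionally lets package __init__.py files re-export names (F401/F403) and defer module-level imports (E402).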

.pre-commit-config.yaml

@@ -5,12 +5,27 @@ default_stages:
- pre-commit # Run locally
# - manual # Run in CI
repos:
- repo: https://github.com/psf/black.git
rev: 22.8.0
hooks:
- id: black
files: \.(py|pyi)$
additional_dependencies: [toml]
# auto-sort imports
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
# code linting
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
hooks:
- id: ruff
args: [--output-format, github, --fix, --line-length=120]
args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
# # spell check
# - repo: https://github.com/codespell-project/codespell
# rev: v2.4.1
@@ -18,17 +33,13 @@ repos:
# - id: codespell
# additional_dependencies: ['tomli']
# args: ['--toml', 'pyproject.toml']
# auto-sort imports
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:
- id: isort
# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29
hooks:
- id: pymarkdown
args: [fix]
args: ["-d", "MD029,MD031", fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:

backend_request_func.py

@@ -29,13 +29,13 @@ from typing import Optional
import aiohttp
from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
"""Input for requesting LLMs via API"""
no: int
prompt: str
history_QA: Optional[dict]
@@ -55,6 +55,7 @@ class RequestFuncInput:
@dataclass
class RequestFuncOutput:
"""Output for requesting LLMs via API"""
no: int = 0
generated_text: str = ""
reasoning_content: str = ""
@@ -76,12 +77,9 @@ async def async_request_eb_openai_chat_completions(
) -> RequestFuncOutput:
"""Request an LLM using EB OpenAI"""
api_url = request_func_input.api_url
assert api_url.endswith(
("completions", "profile")
), "OpenAI Chat Completions API URL must end with 'completions'."
assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
@@ -91,7 +89,7 @@ async def async_request_eb_openai_chat_completions(
"stream": True,
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True
"continuous_usage_stats": True,
},
}
# hyperparameters are passed in via yaml
@@ -100,7 +98,7 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
headers = {
"Content-Type": "application/json",
@@ -115,16 +113,14 @@ async def async_request_eb_openai_chat_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, type(chunk))
timestamp = time.perf_counter()
@@ -138,22 +134,20 @@ async def async_request_eb_openai_chat_completions(
ttft = timestamp - st
output.ttft = ttft
# cached_tokens
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
output.prompt_len = (
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
)
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
output.generated_text += content or ""
output.reasoning_content += reason_content or ""
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage", {}):
output.output_tokens = usage.get(
"completion_tokens", 0)
output.prompt_tokens = usage.get(
"prompt_tokens", 0)
output.output_tokens = usage.get("completion_tokens", 0)
output.prompt_tokens = usage.get("prompt_tokens", 0)
most_recent_timestamp = timestamp
@@ -166,7 +160,12 @@ async def async_request_eb_openai_chat_completions(
output.latency = most_recent_timestamp - st
else:
error_text = await response.text()
print("####error response:", error_text, "####payload:", payload)
print(
"####error response:",
error_text,
"####payload:",
payload,
)
output.error = error_text or ""
output.success = False
except Exception:
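The chunk loop above keeps two reference times: the request start for TTFT and the previous chunk for inter-token latency. A self-contained sketch of that bookkeeping (the helper name and the fake timestamps are illustrative, not part of the benchmark code):

```python
# Minimal sketch of the streaming timing logic: the first chunk yields TTFT,
# every later chunk yields an inter-token latency (ITL) sample.
import time

def record_stream_timing(chunk_arrival_times: list[float], start: float):
    ttft = None
    itl = []
    most_recent = start
    for ts in chunk_arrival_times:
        if ttft is None:
            ttft = ts - start             # time to first token
        else:
            itl.append(ts - most_recent)  # inter-token latency
        most_recent = ts
    return ttft, itl

start = time.perf_counter()
arrivals = [start + 0.12, start + 0.15, start + 0.19]  # fake timestamps
print(record_stream_timing(arrivals, start))
```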
@@ -194,15 +193,14 @@ async def async_request_eb_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
"stream": True,
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True
"continuous_usage_stats": True,
},
}
# hyperparameters are passed in via yaml
@@ -215,7 +213,7 @@ async def async_request_eb_openai_completions(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
output = RequestFuncOutput()
@@ -227,8 +225,7 @@ async def async_request_eb_openai_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
@@ -236,8 +233,7 @@ async def async_request_eb_openai_completions(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, chunk.usage)
timestamp = time.perf_counter()
@@ -259,25 +255,22 @@ async def async_request_eb_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
generated_text += text or ""
most_recent_timestamp = timestamp
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage"):
output.prompt_tokens = usage.get(
"prompt_tokens")
output.output_tokens = usage.get(
"completion_tokens")
output.prompt_tokens = usage.get("prompt_tokens")
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
"Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
@@ -295,7 +288,7 @@ async def async_request_eb_openai_completions(
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print("final_output:{}".format(output))
print(f"final_output:{output}")
if pbar:
pbar.update(1)
@@ -310,8 +303,7 @@ async def async_request_tgi(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
params = {
"max_new_tokens": request_func_input.output_len,
"do_sample": True,
@@ -358,8 +350,7 @@ async def async_request_tgi(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
output.arrival_time.append(data["arrival_time"])
@@ -388,8 +379,7 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
@@ -414,8 +404,7 @@ async def async_request_trt_llm(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data:")
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
@@ -427,8 +416,7 @@ async def async_request_trt_llm(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
@@ -453,8 +441,7 @@ async def async_request_deepspeed_mii(
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
"""Request an LLM using Deepspeed MII"""
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"prompt": request_func_input.prompt,
@@ -472,19 +459,16 @@ async def async_request_deepspeed_mii(
st = time.perf_counter()
try:
async with session.post(url=request_func_input.api_url,
json=payload) as response:
async with session.post(url=request_func_input.api_url, json=payload) as response:
if response.status == 200:
parsed_resp = await response.json()
output.latency = time.perf_counter() - st
if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][
"text"]
output.generated_text = parsed_resp["choices"][0]["text"]
elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0]
else:
output.error = ("Unexpected response format: "
"neither 'choices' nor 'text' found")
output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
output.success = False
output.success = True
else:
@@ -510,11 +494,9 @@ async def async_request_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
"prompt": request_func_input.prompt,
# "temperature": 0.0,
"max_tokens": request_func_input.output_len,
@@ -527,9 +509,7 @@ async def async_request_openai_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -538,8 +518,7 @@ async def async_request_openai_completions(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
first_chunk_received = False
async for chunk_bytes in response.content:
@@ -547,8 +526,7 @@ async def async_request_openai_completions(
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, type(chunk))
data = json.loads(chunk)
@@ -569,21 +547,19 @@ async def async_request_openai_completions(
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
"Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
@@ -606,25 +582,24 @@ async def async_request_openai_audio(
"""Request an LLM using OpenAI"""
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
assert api_url.endswith(
("transcriptions", "translations"
)), "OpenAI Chat Completions API URL must end with 'transcriptions' "
("transcriptions", "translations")
), "OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
"model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
"temperature": 0.0,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
"stream_continuous_usage_stats": True
"stream_continuous_usage_stats": True,
}
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
@@ -639,9 +614,9 @@ async def async_request_openai_audio(
buffer.seek(0)
return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav')
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
@@ -653,24 +628,20 @@ async def async_request_openai_audio(
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url,
data=form,
headers=headers) as response:
async with session.post(url=api_url, data=form, headers=headers) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ")
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get(
"content")
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
@@ -678,13 +649,11 @@ async def async_request_openai_audio(
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp)
output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
@@ -718,8 +687,11 @@ ASYNC_REQUEST_FUNCS = {
}
OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions,
async_request_eb_openai_chat_completions)
k
for k, v in ASYNC_REQUEST_FUNCS.items()
if v
in (
async_request_openai_completions,
async_request_eb_openai_chat_completions,
)
]

benchmark_dataset.py

@@ -26,9 +26,9 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Callable, Optional, Union
from PIL import Image
from typing import Any, Optional, Union
from PIL import Image
logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ class SampleRequest:
"""
Represents a single inference request for benchmarking.
"""
no: int
prompt: Union[str, Any]
history_QA: Union[str, Any]
@@ -48,6 +49,7 @@ class SampleRequest:
class BenchmarkDataset(ABC):
"""BenchmarkDataset"""
DEFAULT_SEED = 0
IS_MULTIMODAL = False
@@ -68,8 +70,7 @@ class BenchmarkDataset(ABC):
self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the
# default seed.
self.random_seed = (random_seed
if random_seed is not None else self.DEFAULT_SEED)
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
self.data = None
self.hyperparameter_path = hyperparameter_path
self.hyperparameters = {}
@@ -85,8 +86,7 @@ class BenchmarkDataset(ABC):
NotImplementedError: If a subclass does not implement this method.
"""
# TODO (jenniferzhao): add support for downloading data
raise NotImplementedError(
"load_data must be implemented in subclasses.")
raise NotImplementedError("load_data must be implemented in subclasses.")
@abstractmethod
def sample(self, num_requests: int) -> list[SampleRequest]:
@@ -105,8 +105,7 @@ class BenchmarkDataset(ABC):
"""
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest],
num_requests: int) -> None:
def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
"""
Oversamples the list of requests if its size is less than the desired
number.
@@ -117,11 +116,9 @@ class BenchmarkDataset(ABC):
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
additional = random.choices(requests,
k=num_requests - len(requests))
additional = random.choices(requests, k=num_requests - len(requests))
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
logger.info("Oversampled requests to reach %d total samples.", num_requests)
def is_valid_sequence(
@@ -141,14 +138,12 @@ def is_valid_sequence(
"""
# Check for invalid conditions
prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len
< min_len)
output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long
or combined_too_long)
return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)
def process_image(image: Any) -> Mapping[str, Any]:
@@ -171,28 +166,25 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(image, dict) and 'bytes' in image:
image = Image.open(BytesIO(image['bytes']))
if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image):
image = image.convert("RGB")
with io.BytesIO() as image_data:
image.save(image_data, format="JPEG")
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
return {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}
if isinstance(image, str):
image_url = (image if image.startswith(
("http://", "file://")) else f"file://{image}")
image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes.")
raise ValueError(
f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
)
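A hedged usage sketch of the process_image result (the file path and prompt below are placeholders): the returned dict drops straight into an OpenAI-style multimodal content list.

```python
# Hypothetical usage of a process_image()-style result; the path and prompt
# are placeholders, not values taken from the benchmark datasets.
image_part = {"type": "image_url", "image_url": {"url": "file:///tmp/example.jpg"}}
content = [{"type": "text", "text": "Describe this image."}, image_part]
messages = [{"role": "user", "content": content}]
print(messages)
```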
class EBDataset(BenchmarkDataset):
@@ -243,8 +235,7 @@ class EBDataset(BenchmarkDataset):
new_output_len = int(entry["max_dec_len"])
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
prompt = self.apply_multimodal_chat_transformation(prompt, None)
samples.append(
SampleRequest(
no=cnt,
@@ -252,17 +243,20 @@ class EBDataset(BenchmarkDataset):
prompt_len=self.prompt_len,
history_QA=[],
expected_output_len=new_output_len,
))
)
)
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples
class EBChatDataset(BenchmarkDataset):
"""
Implements the ShareGPT dataset. Loads data from a JSON file and generates
sample requests based on conversation turns.
"""
prompt_len: int
def __init__(self, **kwargs) -> None:
@@ -296,8 +290,7 @@ class EBChatDataset(BenchmarkDataset):
new_output_len = int(entry.get("max_tokens", 12288))
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
prompt, None)
prompt = self.apply_multimodal_chat_transformation(prompt, None)
samples.append(
SampleRequest(
no=cnt,
@@ -306,9 +299,9 @@ class EBChatDataset(BenchmarkDataset):
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
))
)
)
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples


@@ -18,28 +18,16 @@ import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union
import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_dataset import EBChatDataset, EBDataset
from benchmark_serving import benchmark
def prepare_input_requests(
num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
dataset_mapping = {
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
}
try:
@@ -104,24 +92,27 @@ def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
def main(args):
base_url = f"http://{args.host}:{args.port}"
input_requests = prepare_input_requests(
args.num_prompts, args.dataset_name, args.dataset_path
)
input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)
if len(args.max_concurrency) != len(args.s_itl_base_model):
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
# Warmup
print("Starting warmup...")
with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f):
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
send_one_batch(
base_url,
max_concurrency,
input_requests[0:max_concurrency],
True,
)
# Benchmark
record = send_one_batch(base_url, max_concurrency, input_requests, False)
metric_header = f"Speed up"
metric_header = "Speed up"
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
for draft_token_step in args.draft_token_steps:
speedup = calculate_speedup(
@@ -130,11 +121,7 @@ def main(args):
s_itl,
record["mean_s_itl_ms"],
)
print(
"{:<40} {:<10.2f}".format(
f"Speed up on {draft_token_step} steps draft", speedup
)
)
print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
print("=" * 50)

benchmark_serving.py

@@ -25,22 +25,23 @@ import os
import random
import time
import warnings
import yaml
from argparse import ArgumentParser as FlexibleArgumentParser
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
RequestFuncOutput)
from tqdm.asyncio import tqdm
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
import yaml
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm.asyncio import tqdm
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -48,6 +49,7 @@ MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@dataclass
class BenchmarkMetrics:
"""Class containing all metrics that are used in this script"""
completed: int
total_input: int
total_output: int
@@ -130,8 +132,7 @@ async def get_request(
input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
theta = 1.0 / (request_rate * burstiness)
for request in input_requests:
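The interval sampling itself is outside this hunk; a standalone sketch under the assumption that inter-arrival times are gamma-distributed with shape burstiness and scale theta, so the mean interval is 1 / request_rate and burstiness=1 reduces to a Poisson process:

```python
# Sketch only -- assumes gamma-distributed request intervals, which is what
# theta = 1 / (request_rate * burstiness) implies for the mean to work out.
import numpy as np

def sample_intervals(request_rate: float, burstiness: float, n: int) -> np.ndarray:
    theta = 1.0 / (request_rate * burstiness)
    return np.random.gamma(shape=burstiness, scale=theta, size=n)

intervals = sample_intervals(request_rate=4.0, burstiness=1.0, n=5)
print(intervals, intervals.mean())  # mean interval is roughly 0.25 s
```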
@@ -208,8 +209,9 @@ def calculate_metrics(
s_e2els.append(outputs[i].arrival_time[-1])
# decode speed excludes the first token
if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) /
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
s_decodes.append(
(outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
)
else:
print("len(outputs[i].arrival_time) <= 2")
completed += 1
@@ -224,16 +226,13 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
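A toy illustration (made-up numbers) of the goodput check above: a request counts as good only if every selected metric meets its service level objective.

```python
# Made-up values for illustration: SLOs and per-request metrics in seconds.
slo_values = [0.5, 0.05]                    # e.g. ttft <= 500 ms, tpot <= 50 ms
per_request = [(0.31, 0.04), (0.62, 0.03)]  # (ttft, tpot) for two requests

good = sum(all(s >= r for s, r in zip(slo_values, m)) for m in per_request)
print(good)  # 1 -- the second request misses the TTFT objective
```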
@@ -242,9 +241,9 @@ def calculate_metrics(
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
"All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -253,64 +252,50 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_s_decode=np.mean(s_decodes or 0) *
1, # ttfts is empty if streaming is not supported by backend
mean_s_decode=np.mean(s_decodes or 0) * 1, # ttfts is empty if streaming is not supported by backend
std_s_decode=np.std(s_decodes or 0) * 1,
median_s_decode=np.median(s_decodes or 0) * 1,
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1)
for p in selected_percentiles],
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
for p in selected_percentiles],
mean_s_ttft_ms=np.mean(s_ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
for p in selected_percentiles],
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
std_s_itl_ms=np.std(s_itls or 0) * 1000,
median_s_itl_ms=np.median(s_itls or 0) * 1000,
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
mean_input_len=np.mean(input_lens or 0) * 1,
std_input_len=np.std(input_lens or 0) * 1,
median_input_len=np.median(input_lens or 0) * 1,
percentiles_input_len=[(p, np.percentile(input_lens or 0, p))
for p in selected_percentiles],
percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
std_s_input_len=np.std(infer_input_lens or 0) * 1,
median_s_input_len=np.median(infer_input_lens or 0) * 1,
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p))
for p in selected_percentiles],
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
mean_output_len=np.mean(actual_output_lens or 0) * 1,
std_output_len=np.std(actual_output_lens or 0) * 1,
median_output_len=np.median(actual_output_lens or 0) * 1,
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p))
for p in selected_percentiles],
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
)
return metrics, actual_output_lens
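A self-contained sketch of the percentile bookkeeping used throughout calculate_metrics, with fake TTFT values:

```python
# Fake TTFT samples (seconds) run through the same percentile expression as above.
import numpy as np

ttfts = [0.08, 0.11, 0.09, 0.35, 0.12]
selected_percentiles = [50.0, 99.0]
percentiles_ttft_ms = [(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles]
print(percentiles_ttft_ms)  # [(50.0, 110.0), (99.0, ~340.8)]
```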
@@ -344,9 +329,11 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_output_len, test_no = \
input_requests[0].prompt, \
input_requests[0].expected_output_len, input_requests[0].no
test_prompt, test_output_len, test_no = (
input_requests[0].prompt,
input_requests[0].expected_output_len,
input_requests[0].no,
)
test_history_QA = input_requests[0].history_QA
test_input = RequestFuncInput(
@@ -373,19 +360,19 @@ async def benchmark(
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
[random.choice(lora_modules) \
for _ in range(len(input_requests))])
lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
profile_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
no=test_no,
@@ -393,7 +380,8 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body)
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
@@ -413,21 +401,22 @@ async def benchmark(
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
prompt, output_len, no = request.prompt, request.expected_output_len, request.no
prompt, output_len, no = (
request.prompt,
request.expected_output_len,
request.no,
)
history_QA = request.history_QA
req_model_id, req_model_name = model_id, model_name
@@ -435,7 +424,8 @@ async def benchmark(
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id,
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
no=no,
@@ -446,11 +436,9 @@ async def benchmark(
output_len=output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
pbar=pbar)))
extra_body=extra_body,
)
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
@@ -473,7 +461,6 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
print("benchmark_duration:", benchmark_duration)
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
@@ -483,22 +470,16 @@ async def benchmark(
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
@@ -506,8 +487,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@@ -533,24 +513,25 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length(
@@ -565,31 +546,31 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}")))
result[f"mean_{metric_attribute_name}"] = getattr(
metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(
metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(
metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
getattr(metrics, f"median_{metric_attribute_name}"),
)
)
result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -612,7 +593,6 @@ def benchmark_metrics(
):
"""Benchmark metrics statisticsgenerate benchmark result"""
outputs = []
case_no_list = []
with open(result_file) as f:
for line in f.readlines():
if "RequestFuncOutput" in line:
@@ -634,22 +614,16 @@ def benchmark_metrics(
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
@@ -657,8 +631,7 @@ def benchmark_metrics(
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@@ -682,24 +655,25 @@ def benchmark_metrics(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length(
@@ -714,31 +688,31 @@ def benchmark_metrics(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}")))
result[f"mean_{metric_attribute_name}"] = getattr(
metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(
metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(
metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
getattr(metrics, f"median_{metric_attribute_name}"),
)
)
result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -764,12 +738,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
f"{str(VALID_NAMES)}. ")
f"{VALID_NAMES!s}. "
)
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative.")
"non-negative."
)
return goodput_config_dict
@@ -783,32 +759,37 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" "
'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err
"number in milliseconds."
) from err
return goodput_config_dict
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any],
file_name: str) -> None:
def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
"""Save the benchmarking results to PyTorch Benchmark Format JSON file"""
metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
"median_ttft_ms",
"mean_ttft_ms",
"std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
]
# These raw data might be useful, but they are rather big. They can be added
# later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={k: [results[k]]
for k in metrics},
extra_info={
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
metrics={k: [results[k]] for k in metrics},
extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
)
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
@@ -825,7 +806,6 @@ def main(args: argparse.Namespace):
model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
@@ -835,21 +815,15 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}"
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"EB":
lambda: EBDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"EBChat":
lambda: EBChatDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
@@ -869,15 +843,14 @@ def main(args: argparse.Namespace):
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature
}.items() if v is not None
"temperature": args.temperature,
}.items()
if v is not None
}
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
"Sampling parameters are only supported by openai-compatible "
"backends.")
raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -908,15 +881,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
)
)
# benchmark_result = benchmark_metrics(
# benchmark_duration=3600,
@@ -947,22 +919,23 @@ def main(args: argparse.Namespace):
kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format."
)
raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
if not args.save_detailed:
# Remove fields with too many data points
for field in [
"input_lens", "output_lens", "ttfts", "itls",
"generated_texts", "errors"
"input_lens",
"output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]:
if field in result_json:
del result_json[field]
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
@@ -971,21 +944,19 @@ def main(args: argparse.Namespace):
# Save to file
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None else "")
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile:
with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
@@ -1011,18 +982,29 @@ if __name__ == "__main__":
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"],
choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
parser.add_argument(
"--dataset-path",
type=str,
default=None,
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.")
parser.add_argument("--hyperparameter-path",
help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument(
"--hyperparameter-path",
type=str,
default=None,
help="Path to the hyperparameter. ")
help="Path to the hyperparameter. ",
)
parser.add_argument(
"--max-concurrency",
type=int,
@@ -1034,7 +1016,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
@@ -1045,7 +1028,7 @@ if __name__ == "__main__":
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@@ -1058,11 +1041,13 @@ if __name__ == "__main__":
"--logprobs",
type=int,
default=None,
help=("Number of logprobs-per-token to compute & return as part of "
help=(
"Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
"is enabled 1 logprob per token is computed"
),
)
parser.add_argument(
"--request-rate",
@@ -1099,8 +1084,7 @@ if __name__ == "__main__":
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--save-result",
@@ -1141,35 +1125,38 @@ if __name__ == "__main__":
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
"Default value is \"ttft,tpot,itl\".")
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'Default value is "ttft,tpot,itl".',
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\". "
"Use \"--percentile-metrics\" to select metrics.",
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
'Default value is "99". '
'Use "--percentile-metrics" to select metrics.',
)
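A quick illustration of how these two flags are consumed (main() later in this diff splits them exactly this way); the values are examples only:

selected_percentile_metrics = "ttft,tpot,itl,e2el".split(",")         # ['ttft', 'tpot', 'itl', 'e2el']
selected_percentiles = [float(p) for p in "25,50,99".split(",")]      # [25.0, 50.0, 99.0]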
parser.add_argument(
"--goodput",
nargs="+",
required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
'"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
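A minimal sketch of the expected format, assuming the usual KEY:VALUE split that parse_goodput performs further down in this diff; the numbers are illustrative:

slo_pairs = ["ttft:200", "e2el:5000"]                     # e.g. --goodput ttft:200 e2el:5000
goodput_config_dict = {k: float(v) for k, v in (pair.split(":") for pair in slo_pairs)}
# {'ttft': 200.0, 'e2el': 5000.0}, values in milliseconds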
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -1197,8 +1184,8 @@ if __name__ == "__main__":
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
)
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
@@ -1226,29 +1213,24 @@ if __name__ == "__main__":
"--random-prefix-len",
type=int,
default=0,
help=("Number of fixed prefix tokens before the random context "
help=(
"Number of fixed prefix tokens before the random context "
"in a request. "
"The total input length is the sum of `random-prefix-len` and "
"a random "
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
"input_len * (1 + range_ratio)]."
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
)
sampling_group = parser.add_argument_group("sampling parameters")
@@ -1256,54 +1238,59 @@ if __name__ == "__main__":
"--top-p",
type=float,
default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--top-k",
type=int,
default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--min-p",
type=float,
default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--temperature",
type=float,
default=None,
help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).")
"decoding (i.e. temperature==0.0).",
)
parser.add_argument(
'--tokenizer-mode',
"--tokenizer-mode",
type=str,
default="auto",
choices=['auto', 'slow', 'mistral', 'custom'],
choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
"always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.')
'"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name",
parser.add_argument(
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")
"same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules",
nargs='+',
parser.add_argument(
"--lora-modules",
nargs="+",
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.")
"script chooses a LoRA module at random.",
)
args = parser.parse_args()
main(args)

View File

@@ -24,9 +24,11 @@ import os
from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
def convert_to_pytorch_benchmark_format(
args: argparse.Namespace,
metrics: dict[str, list],
extra_info: dict[str, Any]) -> list:
extra_info: dict[str, Any],
) -> list:
"""
Save the benchmark results in the format used by PyTorch OSS benchmark with
one metric per record
@@ -54,12 +56,10 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
},
}
tp = record["benchmark"]["extra_info"]["args"].get(
"tensor_parallel_size")
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]
records.append(record)
@@ -68,6 +68,7 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder):
"""InfEncoder"""
def clear_inf(self, o: Any):
"""clear_inf"""
if isinstance(o, dict):
@@ -87,4 +88,3 @@ def write_to_json(filename: str, records: list) -> None:
"""write_to_json"""
with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder)
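A hedged usage sketch of the two helpers above; the Namespace fields, metric values, and output path are made up for illustration:

import argparse

from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json

args = argparse.Namespace(model="ERNIE-4.5", tensor_parallel_size=None)  # any parsed args object
records = convert_to_pytorch_benchmark_format(
    args=args,
    metrics={"mean_ttft_ms": [123.4]},           # one metric per record
    extra_info={"tensor_parallel_size": 8},      # copied into extra_info["args"] when absent there
)
write_to_json("example.pytorch.json", records)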

View File

@@ -25,32 +25,32 @@ import os
import random
import time
import warnings
import yaml
import requests
import copy
from argparse import ArgumentParser as FlexibleArgumentParser
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
RequestFuncOutput)
import requests
import yaml
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm.asyncio import tqdm
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@dataclass
class BenchmarkMetrics:
"""Class containing all metrics that are used in this script"""
completed: int
total_input: int
total_output: int
@@ -133,8 +133,7 @@ async def get_request(
input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}.")
assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
theta = 1.0 / (request_rate * burstiness)
for request in input_requests:
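The draw itself sits outside this hunk; as a sketch, sampling the inter-arrival gap from a gamma distribution with shape=burstiness and scale=theta keeps the mean interval at 1/request_rate (burstiness == 1.0 reduces to a Poisson process):

import numpy as np

request_rate, burstiness = 10.0, 1.0
theta = 1.0 / (request_rate * burstiness)
intervals = np.random.default_rng(0).gamma(shape=burstiness, scale=theta, size=5)
# mean(intervals) is close to 1 / request_rate == 0.1 s between requests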
@@ -210,8 +209,9 @@ def calculate_metrics(
s_e2els.append(outputs[i].arrival_time[-1])
# Decode speed: exclude the first token
if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) /
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
s_decodes.append(
(outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
)
completed += 1
else:
actual_output_lens.append(0)
@@ -224,16 +224,13 @@ def calculate_metrics(
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION)
slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
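A worked example of the comparison above, assuming a 200 ms TTFT objective and a request whose measured TTFT is 0.15 s:

MILLISECONDS_TO_SECONDS_CONVERSION = 1000
slo_values = [200.0 / MILLISECONDS_TO_SECONDS_CONVERSION]          # [0.2] seconds
req_metric = (0.15,)                                               # measured TTFT in seconds
is_good_req = all(s >= r for s, r in zip(slo_values, req_metric))  # True -> counts toward goodput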
@@ -242,9 +239,9 @@ def calculate_metrics(
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
"All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -253,64 +250,50 @@ def calculate_metrics(
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_s_decode=np.mean(s_decodes or 0) *
1, # ttfts is empty if streaming is not supported by backend
mean_s_decode=np.mean(s_decodes or 0) * 1, # ttfts is empty if streaming is not supported by backend
std_s_decode=np.std(s_decodes or 0) * 1,
median_s_decode=np.median(s_decodes or 0) * 1,
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1)
for p in selected_percentiles],
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
for p in selected_percentiles],
mean_s_ttft_ms=np.mean(s_ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
for p in selected_percentiles],
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
std_s_itl_ms=np.std(s_itls or 0) * 1000,
median_s_itl_ms=np.median(s_itls or 0) * 1000,
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000)
for p in selected_percentiles],
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
mean_input_len=np.mean(input_lens or 0) * 1,
std_input_len=np.std(input_lens or 0) * 1,
median_input_len=np.median(input_lens or 0) * 1,
percentiles_input_len=[(p, np.percentile(input_lens or 0, p))
for p in selected_percentiles],
percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
std_s_input_len=np.std(infer_input_lens or 0) * 1,
median_s_input_len=np.median(infer_input_lens or 0) * 1,
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p))
for p in selected_percentiles],
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
mean_output_len=np.mean(actual_output_lens or 0) * 1,
std_output_len=np.std(actual_output_lens or 0) * 1,
median_output_len=np.median(actual_output_lens or 0) * 1,
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p))
for p in selected_percentiles],
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
)
return metrics, actual_output_lens
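The repeated `xs or 0` guard above is what keeps these aggregates finite when a backend produces no samples: an empty list is falsy, so NumPy averages the scalar 0 instead of an empty array. A tiny illustration:

import numpy as np

ttfts = []                                    # e.g. streaming not supported by the backend
print(np.mean(ttfts or 0) * 1000)             # 0.0; np.mean([]) would be nan with a RuntimeWarning
print(np.mean([0.125, 0.375] or 0) * 1000)    # 250.0 ms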
@@ -351,20 +334,22 @@ async def benchmark(
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
[random.choice(lora_modules) \
for _ in range(len(input_requests))])
lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
test_prompt = None
test_output_len = None
profile_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
output_len=test_output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body)
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
@@ -384,16 +369,13 @@ async def benchmark(
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
pbar=pbar)
return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
@@ -409,7 +391,8 @@ async def benchmark(
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id,
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
prompt_len=0,
@@ -419,15 +402,15 @@ async def benchmark(
output_len=output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
pbar=pbar)))
extra_body=extra_body,
)
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
print(f"完成时间:{datetime.now()}")
if profile:
print("Stopping profiler...")
test_output_len = None
test_output_len = None
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
@@ -454,22 +437,16 @@ async def benchmark(
)
print("Benchmark complete!!!")
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
@@ -477,8 +454,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
@@ -491,7 +467,6 @@ async def benchmark(
"reasoning_contents": [output.reasoning_content for output in outputs],
"errors": [output.error for output in outputs],
}
quick_result = copy.deepcopy(result)
def process_one_metric(
# E.g., "ttft"
@@ -505,24 +480,25 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length(
@@ -537,31 +513,31 @@ async def benchmark(
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}")))
print("{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}")))
result[f"mean_{metric_attribute_name}"] = getattr(
metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(
metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(
metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
getattr(metrics, f"median_{metric_attribute_name}"),
)
)
result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
value))
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -581,6 +557,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
"""
Quick evaluation summary
"""
def process_quick_metric(
metric_attribute_name: str,
metric_name: str,
@@ -588,7 +565,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
):
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value))
quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value
@@ -600,17 +577,17 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
):
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value))
quick_result[f"mean_{metric_attribute_name}"] = mean_value
print("\n\n\n")
print("{s:{c}^{n}}".format(s=' Benchmark Quick Summary ', n=50, c='='))
print("{s:{c}^{n}}".format(s=" Benchmark Quick Summary ", n=50, c="="))
process_quick_length("s_decode", "Decode", "解码速度(tok/s)")
process_quick_metric("ttft", "TTFT", "Time to First Token")
process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_quick_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_quick_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_quick_metric("itl", "ITL", "Inter-token Latency")
process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_quick_metric("e2el", "E2EL", "End-to-end Latency")
@@ -633,12 +610,14 @@ def check_goodput_args(args):
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
f"{str(VALID_NAMES)}. ")
f"{VALID_NAMES!s}. "
)
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative.")
"non-negative."
)
return goodput_config_dict
@@ -652,37 +631,43 @@ def parse_goodput(slo_pairs):
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" "
'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err
"number in milliseconds."
) from err
return goodput_config_dict
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any],
file_name: str) -> None:
def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
"""Save the benchmarking results to PyTorch Benchmark Format JSON file"""
metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
"median_ttft_ms",
"mean_ttft_ms",
"std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
]
# These raw data might be useful, but they are rather big. They can be added
# later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={k: [results[k]]
for k in metrics},
extra_info={
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
metrics={k: [results[k]] for k in metrics},
extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
)
if pt_records:
# Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def check_health(api_base_url: str) -> bool:
health_url = api_base_url.rstrip("/") + "/health"
try:
@@ -697,6 +682,7 @@ def check_health(api_base_url: str) -> bool:
print(f"[HEALTH] Failed to connect to {health_url}: {e}")
return False
def main(args: argparse.Namespace):
"""Main entry point"""
print(args)
@@ -707,7 +693,6 @@ def main(args: argparse.Namespace):
model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
@@ -717,21 +702,15 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}"
if args.dataset_name is None:
raise ValueError(
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"EB":
lambda: EBDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
"EBChat":
lambda: EBChatDataset(random_seed=args.seed,
dataset_path=args.dataset_path).sample(
"EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
),
@@ -751,15 +730,14 @@ def main(args: argparse.Namespace):
"top_p": args.top_p,
"top_k": args.top_k,
"min_p": args.min_p,
"temperature": args.temperature
}.items() if v is not None
"temperature": args.temperature,
}.items()
if v is not None
}
# Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError(
"Sampling parameters are only supported by openai-compatible "
"backends.")
raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")
if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -790,15 +768,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[
float(p) for p in args.metric_percentiles.split(",")
],
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
)
)
# Save config and results to json
if args.save_result:
@@ -819,22 +796,23 @@ def main(args: argparse.Namespace):
kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid metadata format. Please use KEY=VALUE format."
)
raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
if not args.save_detailed:
# Remove fields with too many data points
for field in [
"input_lens", "output_lens", "ttfts", "itls",
"generated_texts", "errors"
"input_lens",
"output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]:
if field in result_json:
del result_json[field]
# Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate
< float("inf") else "inf")
result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
@@ -843,21 +821,19 @@ def main(args: argparse.Namespace):
# Save to file
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None else "")
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile:
with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
@@ -883,18 +859,29 @@ if __name__ == "__main__":
"--dataset-name",
type=str,
default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"],
choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--dataset-path",
parser.add_argument(
"--dataset-path",
type=str,
default=None,
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.")
parser.add_argument("--hyperparameter-path",
help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument(
"--hyperparameter-path",
type=str,
default=None,
help="Path to the hyperparameter. ")
help="Path to the hyperparameter. ",
)
parser.add_argument(
"--max-concurrency",
type=int,
@@ -906,7 +893,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.")
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
@@ -917,7 +905,7 @@ if __name__ == "__main__":
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
@@ -930,11 +918,13 @@ if __name__ == "__main__":
"--logprobs",
type=int,
default=None,
help=("Number of logprobs-per-token to compute & return as part of "
help=(
"Number of logprobs-per-token to compute & return as part of "
"the request. If unspecified, then either (1) if beam search "
"is disabled, no logprobs are computed & a single dummy "
"logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"),
"is enabled 1 logprob per token is computed"
),
)
parser.add_argument(
"--request-rate",
@@ -971,8 +961,7 @@ if __name__ == "__main__":
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--save-result",
@@ -1013,35 +1002,38 @@ if __name__ == "__main__":
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
"Default value is \"ttft,tpot,itl\".")
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'Default value is "ttft,tpot,itl".',
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
"Default value is \"99\". "
"Use \"--percentile-metrics\" to select metrics.",
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
'Default value is "99". '
'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
'"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -1069,8 +1061,8 @@ if __name__ == "__main__":
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
)
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
@@ -1098,29 +1090,24 @@ if __name__ == "__main__":
"--random-prefix-len",
type=int,
default=0,
help=("Number of fixed prefix tokens before the random context "
help=(
"Number of fixed prefix tokens before the random context "
"in a request. "
"The total input length is the sum of `random-prefix-len` and "
"a random "
"context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."),
"input_len * (1 + range_ratio)]."
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
)
sampling_group = parser.add_argument_group("sampling parameters")
@@ -1128,52 +1115,58 @@ if __name__ == "__main__":
"--top-p",
type=float,
default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--top-k",
type=int,
default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--min-p",
type=float,
default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible "
"backends.")
help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
)
sampling_group.add_argument(
"--temperature",
type=float,
default=None,
help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).")
"decoding (i.e. temperature==0.0).",
)
parser.add_argument(
'--tokenizer-mode',
"--tokenizer-mode",
type=str,
default="auto",
choices=['auto', 'slow', 'mistral', 'custom'],
choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
"always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.')
'"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name",
parser.add_argument(
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")
"same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules",
nargs='+',
parser.add_argument(
"--lora-modules",
nargs="+",
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.")
"script chooses a LoRA module at random.",
)
args = parser.parse_args()

View File

@@ -57,5 +57,3 @@ paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
num_experts);
return token_nums_per_expert;
}

View File

@@ -14,9 +14,10 @@
"""read_ids"""
import os
import numpy as np
import struct
import numpy as np
def deserialize_from_file(fp):
"""deserialize from file"""

View File

@@ -13,9 +13,10 @@
# limitations under the License.
"""read temp_ids from file"""
import os
import numpy as np
import struct
import numpy as np
def deserialize_from_file(fp):
"""

View File

@@ -118,4 +118,3 @@ void CUDART_CB RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_que
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
// std::printf("#### save_cache_kv_complete_signal_layerwise_per_query);
}

View File

@@ -41,8 +41,7 @@ ROOT_DIR = Path(__file__).parent.parent
# cannot import envs directly because it depends on fastdeploy,
# which is not installed yet
envs = load_module_from_path('envs',
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
@@ -143,8 +142,7 @@ def get_nvcc_version():
"""
Get cuda version of nvcc.
"""
nvcc_output = subprocess.check_output(["nvcc", "--version"],
universal_newlines=True)
nvcc_output = subprocess.check_output(["nvcc", "--version"], universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = float(output[release_idx].split(",")[0])
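For illustration, the parsing above applied to a typical `nvcc --version` banner tail (the exact wording can differ between CUDA releases):

output = "Cuda compilation tools, release 12.4, V12.4.131".split()
release_idx = output.index("release") + 1
print(float(output[release_idx].split(",")[0]))   # 12.4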
@@ -160,13 +158,19 @@ def get_gencode_flags(archs):
for cc_val in cc_s:
if cc_val == 90:
arch_code = "90a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
# Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
# https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
# "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
arch_code = "100a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"]
flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
else:
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
return flags
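The per-arch mapping above, spelled out for a few assumed compute capabilities (9.0 and 10.0 get the architecture-specific "a" suffix, everything else passes through unchanged):

for cc_val, arch_code in [(80, "80"), (90, "90a"), (100, "100a")]:
    print("-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}")
# -gencode arch=compute_80,code=sm_80
# -gencode arch=compute_90a,code=sm_90a
# -gencode arch=compute_100a,code=sm_100a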
@@ -302,8 +306,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
if not os.path.exists(cutlass_dir):
os.makedirs(cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git",
cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
if not os.listdir(cutlass_dir):
raise ValueError("Git clone cutlass failed!")
@@ -312,8 +315,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
if not os.path.exists(deep_gemm_dir):
os.makedirs(deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git",
deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
if not os.listdir(deep_gemm_dir):
raise ValueError("Git clone DeepGEMM failed!")
cur_path = os.path.dirname(os.path.abspath(__file__))
@@ -347,15 +349,13 @@ elif paddle.is_compiled_with_cuda():
try:
shutil.copytree(src_dir, dst_dir)
except Exception as e:
raise RuntimeError(
f"Failed to copy from {src_dir} to {dst_dir}: {e}")
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git",
json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!")
@@ -372,7 +372,7 @@ elif paddle.is_compiled_with_cuda():
"-Ithird_party/nlohmann_json/include",
]
nvcc_version = get_nvcc_version()
print(f'nvcc_version = {nvcc_version}')
print(f"nvcc_version = {nvcc_version}")
if nvcc_version >= 12.0:
sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
cc = max(get_sm_version(archs))
@@ -414,9 +414,7 @@ elif paddle.is_compiled_with_cuda():
# Running generate fp8 gemm codes.
# Common for SM89, SM90, SM100 (Blackwell)
nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += [
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
]
nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
@@ -431,14 +429,9 @@ elif paddle.is_compiled_with_cuda():
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
]
print("SM90: Running SM90-specific FP8 kernel auto-generation.")
os.system(
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")
nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1",
@@ -473,14 +466,12 @@ elif paddle.is_compiled_with_cuda():
else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else: # For cc == 89 (Ada)
print("SM89: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
# Common FP8 sources for SM89+
sources += [
@@ -493,7 +484,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
"gpu_ops/fused_hadamard_quant_fp8.cu"
"gpu_ops/fused_hadamard_quant_fp8.cu",
]
sources += find_end_files(fp8_auto_gen_directory, ".cu")

View File

@@ -27,7 +27,8 @@ setup(
"cpu_ops/rebuild_padding.cc",
],
extra_compile_args=[
"-DPy_LIMITED_API=0x03090000", "-DPADDLE_ON_INFERENCE"
"-DPy_LIMITED_API=0x03090000",
"-DPADDLE_ON_INFERENCE",
],
),
)

View File

@@ -26,8 +26,7 @@ ROOT_DIR = Path(__file__).parent.parent
# which is not installed yet
from .setup_ops import load_module_from_path
envs = load_module_from_path('envs',
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
BUILDING_ARCS = []
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

View File

@@ -48,8 +48,7 @@ def get_candidate_configs(sm):
candidate_configs = list()
hasbias = ("false", "true")
KernelSchedule = (
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", )
KernelSchedule = ("KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>",)
EpilogueSchedule = ("TmaWarpSpecializedCooperative",)
TileSchedule = ("PersistentScheduler", "StreamKScheduler")
for act_tag in [
@@ -57,8 +56,18 @@ def get_candidate_configs(sm):
# ("relu", "ReLu"),
# ("gelu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule, TileSchedule)])
candidate_configs.extend(
[
(
hasbias,
act_tag,
tiles,
KernelSchedule,
EpilogueSchedule,
TileSchedule,
)
]
)
return candidate_configs
@@ -66,16 +75,13 @@ def get_shape_str(tile_shape):
"""
return tile_shape string.
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
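A worked example of the comprehension above; the underscore-padded tile strings are an assumption inferred from the strip("_") calls, not values copied from this diff:

tile_shape = ("<_128, _128, _64>", "<_1, _2, _1>")
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
print(blocks, clusters)   # ['128', '128', '64'] ['1', '2', '1']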
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule,
tile_schedule):
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, tile_schedule):
"""
check the cutlass config valid.
"""
@@ -304,13 +310,10 @@ def SubstituteTemplate(template, values_base):
SubstituteTemplate
"""
values = copy.deepcopy(values_base)
if values.get("KernelSchedule"
) is not None and "Auto" in values["KernelSchedule"]:
if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule"
) is not None and "Auto" in values["EpilogueSchedule"]:
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template
changed = True
while changed:
@@ -329,8 +332,7 @@ def parse_args():
parse_args
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -346,14 +348,14 @@ def parse_args():
# generate source .cu
def generate_source_cu(
inputs_type: (str),
outputs_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
TileSchedule: (str),
inputs_type: str,
outputs_type: str,
hasbiases: str,
act_tag: str,
tiles: str,
KernelSchedule: str,
EpilogueSchedule: str,
TileSchedule: str,
sm: str,
):
"""
@@ -369,8 +371,11 @@ def generate_source_cu(
for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule:
if not check_config_valid(
tile_config, kernel_schedule,
epilogue_schedule, tile_schedule):
tile_config,
kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue
value_dict = {
"input_type": input_type,
@@ -385,30 +390,32 @@ def generate_source_cu(
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
all_code += SubstituteTemplate(GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
generate_dir: str,
inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
"""
generate_launch_gemm_cus
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
TileSchedule: (str) = single_config[5]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
TileSchedule: str = single_config[5]
code_map = {}
head_path = os.path.join(generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_path = os.path.join(generate_dir, f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
@@ -418,19 +425,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_config, kernel_schedule,
if not check_config_valid(
tile_config,
kernel_schedule,
epilogue_schedule,
tile_schedule):
tile_schedule,
):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
"sm": sm[-2:],
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
}
head_all_code += SubstituteTemplate(
LaunchGemmDeclare, value_dict)
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
@@ -444,19 +451,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule,
tile_schedule):
tile_schedule,
):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
"sm": sm[-2:],
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
}
source_all_code = SubstituteTemplate(
LaunchGemmPart0, value_dict)
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
@@ -476,16 +483,14 @@ def generate_launch_gemm_cus(
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
gemm_config_str = gemm_config_str.replace("<", "").replace(
">", "")
gemm_config_str = gemm_config_str.replace("<", "").replace(">", "")
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
@@ -495,19 +500,18 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
"""
generate_dispatch_gemm_cu
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
TileSchedule: (str) = single_config[5]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
TileSchedule: str = single_config[5]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
for input_type in inputs_type:
@@ -530,9 +534,12 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule,
tile_schedule):
tile_schedule,
):
continue
value_dict = {
"TileShape": tile_shape[0],
@@ -554,18 +561,18 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
if not check_config_valid(
tile_shape,
kernel_schedule,
epilogue_schedule,
tile_schedule):
tile_schedule,
):
continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = {
"sm":
sm[-2:],
"tile_id":
str(tile_id),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
"sm": sm[-2:],
"tile_id": str(tile_id),
"gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
@@ -610,12 +617,17 @@ if __name__ == "__main__":
f.close()
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
outputs_type, fuse_gemm_configs, sm_dict[sm])
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag
file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/"
f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu")
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,

View File

@@ -24,7 +24,8 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
@@ -38,13 +39,13 @@ def get_candidate_tiles():
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
]
)
return base_configs
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages,
max_stages):
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
"""
get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel.
"""
@@ -299,8 +300,7 @@ def check_min_split_k(value):
"""
ivalue = int(value)
if ivalue > 1:
raise argparse.ArgumentTypeError(
"Dual gemm split_k mode is not support.")
raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.")
return ivalue
@@ -310,8 +310,7 @@ def check_max_split_k(value):
"""
ivalue = int(value)
if ivalue > 1:
raise argparse.ArgumentTypeError(
"Dual gemm split_k mode is not support..")
raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support..")
return ivalue
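For reference, validators of this shape are normally wired into the CLI through argparse's type= hook; the minimal sketch below illustrates that wiring, and the flag names are assumptions for illustration rather than values read from this diff.

import argparse

parser = argparse.ArgumentParser()
# type= runs the validator on the raw string; the functions above then reject split_k > 1.
parser.add_argument("--min_split_k", type=check_min_split_k, default=1)
parser.add_argument("--max_split_k", type=check_max_split_k, default=1)
args = parser.parse_args(["--max_split_k", "1"])  # passing 2 would raise the ArgumentTypeError above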
@@ -320,8 +319,7 @@ def parse_args():
parse_args
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -421,8 +419,7 @@ def generate_dual_gemm_source_cu(
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(
GemmSplitKDeclare, value_dict)
all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
all_code += CommonTail
return all_code
@@ -449,12 +446,12 @@ def generate_launch_dual_gemm_cus(
head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
@@ -467,12 +464,12 @@ def generate_launch_dual_gemm_cus(
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
@@ -498,16 +495,14 @@ def generate_launch_dual_gemm_cus(
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
# split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
# source_all_code += split_k_code
# source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -566,12 +561,12 @@ def generate_dispatch_dual_gemm_cu(
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}")
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = {
@@ -580,10 +575,12 @@ def generate_dispatch_dual_gemm_cu(
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
value_dict.update({
value_dict.update(
{
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
})
}
)
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
@@ -602,8 +599,7 @@ if __name__ == "__main__":
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_dual_gemm_candidate_configs(
sm, min_split_k, max_split_k, min_stages, max_stages)
fuse_gemm_configs = get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"

View File

@@ -19,8 +19,7 @@ import re
def get_candidate_tiles():
"""
"""
""" """
cta_shape = [
("<_64, _16, _128>"),
("<_64, _32, _128>"),
@@ -45,8 +44,7 @@ def get_candidate_tiles():
def get_dual_gemm_candidate_configs(sm):
"""
"""
""" """
tiles = get_candidate_tiles()
candidate_configs = list()
@@ -64,35 +62,27 @@ def get_dual_gemm_candidate_configs(sm):
("swiglu", "SiLu"),
("geglu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule)])
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
""" """
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
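A quick usage sketch of the helper above; the example strings follow the "<_M, _N, _K>" format used by these SM90 generators and are chosen here only for illustration.

tile_shape = ("<_64, _16, _128>", "<_1, _2, _1>")  # (CTA tile, cluster shape)
blocks, clusters = get_shape_str(tile_shape)
print(blocks, clusters)  # -> ['64', '16', '128'] ['1', '2', '1']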
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
"""
"""
""" """
blocks, clusters = get_shape_str(tile_shape)
if int(
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False
if tile_shape[
0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
return False
return True
@@ -302,8 +292,7 @@ bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
def SubstituteTemplate(template, values):
"""
"""
""" """
text = template
changed = True
while changed:
@@ -318,10 +307,8 @@ def SubstituteTemplate(template, values):
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
""" """
parser = argparse.ArgumentParser(description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser.add_argument(
"--cuda_arch",
type=str,
@@ -336,17 +323,16 @@ def parse_args():
# generate source .cu
def generate_dual_gemm_source_cu(
inputs_type: (str),
biases_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
inputs_type: str,
biases_type: str,
hasbiases: str,
act_tag: str,
tiles: str,
KernelSchedule: str,
EpilogueSchedule: str,
sm: str,
):
"""
"""
""" """
all_code = CommonHead
for input_type in inputs_type:
for bias_type in biases_type:
@@ -354,9 +340,7 @@ def generate_dual_gemm_source_cu(
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config,
kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"input_type": input_type,
@@ -370,28 +354,29 @@ def generate_dual_gemm_source_cu(
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
all_code += SubstituteTemplate(GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_dual_gemm_cus(
generate_dir: (str), inputs_type: (str), biases_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
generate_dir: str,
inputs_type: str,
biases_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
code_map = {}
head_path = os.path.join(generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config)
@@ -401,16 +386,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare,
value_dict)
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
@@ -422,16 +405,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0,
value_dict)
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for bias_type in biases_type:
@@ -450,14 +431,13 @@ def generate_launch_dual_gemm_cus(
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu"
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
@@ -467,16 +447,14 @@ def generate_launch_dual_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
def generate_dispatch_dual_gemm_cu(inputs_type: str, biases_type: str, fuse_gemm_configs: tuple, sm: str):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
@@ -500,8 +478,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
@@ -520,8 +497,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
@@ -570,12 +546,15 @@ if __name__ == "__main__":
f.close()
# Compile parallelization
generate_launch_dual_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
biases_type, fuse_gemm_configs, sm_dict[sm])
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
biases_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_dual_gemm_cu(
inputs_type,

View File

@@ -31,7 +31,8 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
base_configs.extend(
[
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
@@ -43,13 +44,13 @@ def get_candidate_tiles():
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
]
)
return base_configs
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages,
max_stages):
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
"""
Get the list of candidate gemm kernel configurations.
@@ -353,8 +354,7 @@ def parse_args():
Parse command-line arguments.
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -448,8 +448,7 @@ def generate_source_cu(
"hasbias": hasbias,
"SM": sm,
}
all_code += SubstituteTemplate(GemmSplitKDeclare,
value_dict)
all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
all_code += CommonTail
return all_code
@@ -473,9 +472,7 @@ def generate_launch_gemm_cus(
head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -489,9 +486,7 @@ def generate_launch_gemm_cus(
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
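To make the naming convention above concrete, a small standalone sketch; the tile is the first entry of base_configs shown earlier in this file, and the stage value is an arbitrary example.

tile = ("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = (
    f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
    f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
    f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
)
print(gemm_config + "_stage3")  # -> block64x64x64_warp32x32x64_mma16x8x32_stage3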
@@ -517,17 +512,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
split_k_code += SubstituteTemplate(
LaunchGemmPart3, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
source_all_code += split_k_code
source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -581,9 +573,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -593,10 +583,12 @@ def generate_dispatch_gemm_cu(
}
all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1
value_dict.update({
value_dict.update(
{
"min_split_k": str(min_split_k),
"max_split_k": str(max_split_k),
})
}
)
all_code += SubstituteTemplate(code_part6, value_dict)
return all_code
@@ -614,9 +606,7 @@ if __name__ == "__main__":
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_split_k,
max_split_k, min_stages,
max_stages)
fuse_gemm_configs = get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
all_code = generate_source_cu(
@@ -654,9 +644,7 @@ if __name__ == "__main__":
# hard code for act_tag
file_name = (
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
)
file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,

View File

@@ -20,14 +20,14 @@ import re
def get_candidate_tiles():
"""
"""
""" """
base_configs = [
("<_64, _64, _128>", "<_1, _8, _1>"),
("<_64, _128, _128>", "<_2, _1, _1>"),
("<_128, _128, _128>", "<_2, _1, _1>"),
]
base_configs.extend([
base_configs.extend(
[
("<_64, _64, _128>", "<_1, _1, _1>"),
("<_64, _64, _128>", "<_1, _2, _1>"),
("<_64, _64, _128>", "<_2, _1, _1>"),
@@ -50,14 +50,14 @@ def get_candidate_tiles():
# ("<_128, _128, _128>", "<_1, _1, _1>"),
# ("<_128, _128, _128>", "<_1, _4, _1>"),
# ("<_128, _128, _64>", "<_2, _2, _1>"),
])
]
)
return base_configs
def get_candidate_configs(sm):
"""
"""
""" """
tiles = get_candidate_tiles()
candidate_configs = list()
@@ -73,36 +73,31 @@ def get_candidate_configs(sm):
("relu", "ReLu"),
("gelu", "GELU"),
]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule,
EpilogueSchedule)])
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
return candidate_configs
def get_shape_str(tile_shape):
"""
"""
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
""" """
blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
"""
"""
""" """
blocks, clusters = get_shape_str(tile_shape)
if int(
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False
if (tile_shape[0] == "<_256, _128, _128>"
if (
tile_shape[0] == "<_256, _128, _128>"
and "Cooperative" not in kernel_schedule
and "Cooperative" not in epilogue_schedule):
and "Cooperative" not in epilogue_schedule
):
return False
return True
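A usage sketch of the filter above; the schedule identifiers below are examples of the CUTLASS names these generators iterate over, picked here purely for illustration.

coop_kernel = "KernelTmaWarpSpecializedCooperativeFP8FastAccum"
coop_epilogue = "TmaWarpSpecializedCooperative"

# Rejected: a 64-wide block M is too small for the cooperative FP8 fast-accum kernel schedule.
print(check_config_valid(("<_64, _64, _128>", "<_1, _1, _1>"), coop_kernel, coop_epilogue))    # False

# Accepted: 128-wide block, cooperative kernel schedule paired with a cooperative epilogue.
print(check_config_valid(("<_128, _128, _128>", "<_2, _1, _1>"), coop_kernel, coop_epilogue))  # True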
@@ -321,16 +316,12 @@ bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
def SubstituteTemplate(template, values_base):
"""
"""
""" """
values = copy.deepcopy(values_base)
if values.get("KernelSchedule"
) is not None and "Auto" in values["KernelSchedule"]:
if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule"
) is not None and "Auto" in values["EpilogueSchedule"]:
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template
changed = True
while changed:
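The body of the substitution loop is cut off by the hunk boundary above; the sketch below shows what such a loop typically looks like in these generators, assuming ${name}-style placeholders (the placeholder syntax is an assumption, not something visible in this diff).

import re

def substitute_template(template, values):
    # Hedged sketch only: keep replacing ${key} placeholders until the text stops changing.
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = r"\$\{%s\}" % key  # assumed placeholder syntax
            newtext = re.sub(regex, value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text

print(substitute_template("launch_${gemm_config}.cu", {"gemm_config": "block64x64x64"}))
# -> launch_block64x64x64.cu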
@@ -345,10 +336,8 @@ def SubstituteTemplate(template, values_base):
def parse_args():
"""
"""
parser = argparse.ArgumentParser(
description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
""" """
parser = argparse.ArgumentParser(description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser.add_argument(
"--cuda_arch",
type=str,
@@ -363,17 +352,16 @@ def parse_args():
# generate source .cu
def generate_source_cu(
inputs_type: (str),
outputs_type: (str),
hasbiases: (str),
act_tag: (str),
tiles: (str),
KernelSchedule: (str),
EpilogueSchedule: (str),
inputs_type: str,
outputs_type: str,
hasbiases: str,
act_tag: str,
tiles: str,
KernelSchedule: str,
EpilogueSchedule: str,
sm: str,
):
"""
"""
""" """
all_code = CommonHead
for input_type in inputs_type:
@@ -382,9 +370,7 @@ def generate_source_cu(
for tile_config in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config,
kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"input_type": input_type,
@@ -398,25 +384,27 @@ def generate_source_cu(
"SM": sm,
"sm": sm[-2:],
}
all_code += SubstituteTemplate(
GemmDeclare, value_dict)
all_code += SubstituteTemplate(GemmDeclare, value_dict)
return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
generate_dir: str,
inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
code_map = {}
head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead
@@ -426,16 +414,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
head_all_code += SubstituteTemplate(LaunchGemmDeclare,
value_dict)
head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f:
f.write(head_all_code)
@@ -447,16 +433,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
"sm": sm[-2:],
"gemm_config": gemm_config_str,
}
source_all_code = SubstituteTemplate(LaunchGemmPart0,
value_dict)
source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
type_id = 0
for input_type in inputs_type:
for output_type in outputs_type:
@@ -475,14 +459,14 @@ def generate_launch_gemm_cus(
"SM": sm,
"sm": sm[-2:],
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu")
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -491,17 +475,15 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
fuse_gemm_configs: tuple, sm: str):
"""
"""
def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0]
tiles: (str) = single_config[2]
KernelSchedule: (str) = single_config[3]
EpilogueSchedule: (str) = single_config[4]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
@@ -524,8 +506,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for tile_shape in tiles:
for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
value_dict = {
"TileShape": tile_shape[0],
@@ -544,8 +525,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule,
epilogue_schedule):
if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = {
@@ -576,7 +556,8 @@ if __name__ == "__main__":
for fuse_gemm_config in fuse_gemm_configs:
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu")
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
)
all_code = generate_source_cu(
inputs_type,
outputs_type,
@@ -594,8 +575,12 @@ if __name__ == "__main__":
f.close()
# Compile parallelization
generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type,
outputs_type, fuse_gemm_configs, sm_dict[sm])
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"

View File

@@ -30,7 +30,8 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([
base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
@@ -45,7 +46,8 @@ def get_candidate_tiles():
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
])
]
)
return base_configs
@@ -278,8 +280,7 @@ def parse_args():
Parse command-line arguments.
"""
parser = argparse.ArgumentParser(
description=
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument(
"--cuda_arch",
@@ -370,13 +371,10 @@ def generate_launch_gemm_cus(
- dict (code_map) - a dict mapping each Gemm configuration to its generated source code, in the form {"gemm_config": source_code}.
"""
code_map = {}
head_path = os.path.join(generate_dir,
"launch_visitor_gemm_fused_kernel.h")
head_path = os.path.join(generate_dir, "launch_visitor_gemm_fused_kernel.h")
head_all_code = LaunchGemmHead
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -390,9 +388,7 @@ def generate_launch_gemm_cus(
f.close()
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -415,14 +411,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage),
"SM": sm,
}
source_all_code += SubstituteTemplate(
LaunchGemmPart1, value_dict)
source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
type_id += 1
source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code
source_path = os.path.join(
generate_dir,
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu")
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu",
)
with open(source_path, "w") as f:
f.write(source_all_code)
f.close()
@@ -485,9 +481,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4
tile_id = 0
for tile in tiles:
blocks, warps, mmas = [
s.replace(" ", "").strip("<>").split(",") for s in tile
]
blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}"
@@ -512,10 +506,11 @@ if __name__ == "__main__":
for sm in archs:
if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_stages,
max_stages)
fuse_gemm_configs = get_candidate_configs(sm, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
)
all_code = generate_source_cu(
inputs_type,
outputs_type,
@@ -544,9 +539,7 @@ if __name__ == "__main__":
sm_dict[sm],
)
file_name = (
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
)
file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
all_code = generate_dispatch_gemm_cu(
inputs_type,
outputs_type,

View File

@@ -30,8 +30,7 @@ current_file = Path(__file__).resolve()
base_dir = current_file.parent
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
XDNN_LIB_DIR):
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
"""
build xpu plugin
"""
@@ -49,7 +48,10 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# Remove the specified directories
dirs_to_remove = [
"dist", "fastdeploy_ops.egg-info", "build", "plugin/build"
"dist",
"fastdeploy_ops.egg-info",
"build",
"plugin/build",
]
for dir_name in dirs_to_remove:
if os.path.exists(dir_name):
@@ -58,8 +60,7 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# Run the build script in the plugin directory
plugin_dir = "plugin"
build_script = os.path.join(current_working_directory, plugin_dir,
"build.sh")
build_script = os.path.join(current_working_directory, plugin_dir, "build.sh")
print("build_script: ", build_script)
@@ -74,14 +75,16 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# Run the build script
try:
print("Running build script...")
subprocess.run([build_script],
subprocess.run(
[build_script],
check=True,
cwd=os.path.join(current_working_directory, plugin_dir))
cwd=os.path.join(current_working_directory, plugin_dir),
)
print("Build completed successfully.")
except subprocess.CalledProcessError as e:
print(f"Build failed with error: {e}")
except Exception as e:
print(f"Unexpected error: {str(e)}")
print(f"Unexpected error: {e!s}")
def xpu_setup_ops():
@@ -124,17 +127,14 @@ def xpu_setup_ops():
XVLLM_PATH = os.getenv("XVLLM_PATH")
assert XVLLM_PATH is not None, "XVLLM_PATH is not set."
XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include")
XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so",
"libapiinfer.so")
XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so", "libapiinfer.so")
XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so")
XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include")
XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so",
"libxft_blocks.so")
XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so", "libxft_blocks.so")
XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so")
# build plugin
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH,
XDNN_LIB_DIR)
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
ops = [
# custom ops
@@ -152,7 +152,6 @@ def xpu_setup_ops():
"./ops/block_attn.cc",
"./ops/moe_layer.cc",
"./ops/weight_quantize_xpu.cc",
# device manage ops
"./ops/device/get_context_gm_max_mem_demand.cc",
"./ops/device/get_free_global_memory.cc",

View File

@@ -29,7 +29,7 @@ for i in range(bs):
ids_len = seq_lens[i, 0]
input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64")
x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
(x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k,) = get_padding_offset(
paddle.to_tensor(input_ids),
paddle.to_tensor(cum_offset),
paddle.to_tensor(token_num),
@@ -46,19 +46,14 @@ print("padding_offset:\n", padding_offset)
print("cu_seqlens_q:\n", cu_seqlens_q)
print("cu_seqlens_k:\n", cu_seqlens_k)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6],
"int64")
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
ref_cum_offsets_out = np.array([0, 6, 13], "int32")
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13],
"int32")
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
assert sum(ref_x_remove_padding -
x_remove_padding) == 0, 'Check x_remove_padding failed.'
assert sum(ref_cum_offsets_out -
cum_offsets_out) == 0, 'Check cum_offsets_out failed.'
assert sum(ref_padding_offset -
padding_offset) == 0, 'Check padding_offset failed.'
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, 'Check cu_seqlens_q failed.'
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, 'Check cu_seqlens_k failed.'
assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
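The reference arrays above can be re-derived from the per-request lengths alone; a small numpy sketch follows, where seq_lens and max_seq_len are inferred from the expected values rather than read from the (hidden) test setup.

import numpy as np

seq_lens = np.array([4, 3, 6], "int32")  # inferred from ref_cu_seqlens_q = [0, 4, 7, 13]
max_seq_len = 10                         # assumed padded length per request

pads = max_seq_len - seq_lens                                                  # padding tokens per request
cum_offsets_out = np.concatenate([[0], np.cumsum(pads)[:-1]]).astype("int32")  # -> [0, 6, 13]
padding_offset = np.repeat(cum_offsets_out, seq_lens)                          # -> [0]*4 + [6]*3 + [13]*6
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)]).astype("int32")        # -> [0, 4, 7, 13]
print(cum_offsets_out, padding_offset, cu_seqlens)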

View File

@@ -21,10 +21,15 @@ paddle.seed(2023)
pre_ids = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64")
logits = paddle.to_tensor([[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1]],
"float32")
"int64",
)
logits = paddle.to_tensor(
[
[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1],
],
"float32",
)
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
frequency_scores = paddle.to_tensor([0.1, 0.1], "float32")
presence_scores = paddle.to_tensor([0.0, 0.0], "float32")
@@ -88,78 +93,536 @@ ref_logits = np.array(
)
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, 'Check failed.'
assert diff_logits < 1e-6, "Check failed."
pre_ids = paddle.to_tensor(
[[
2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8,
7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4,
9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6,
1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8,
7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4,
7, 5, 5, 7, 6, 9, 3, 9
[
[
2,
3,
3,
5,
8,
9,
3,
9,
1,
8,
9,
2,
3,
8,
8,
9,
9,
1,
4,
2,
6,
2,
6,
8,
7,
2,
2,
3,
8,
1,
5,
7,
9,
2,
2,
9,
1,
4,
9,
8,
5,
8,
5,
7,
3,
6,
4,
4,
9,
9,
8,
5,
5,
2,
2,
9,
4,
8,
1,
9,
6,
9,
2,
2,
7,
2,
2,
9,
4,
6,
4,
6,
1,
4,
1,
9,
1,
8,
8,
5,
7,
9,
4,
2,
5,
1,
1,
4,
1,
5,
5,
4,
4,
2,
1,
8,
7,
1,
2,
9,
6,
7,
9,
6,
7,
7,
4,
9,
9,
7,
5,
1,
8,
9,
8,
8,
5,
4,
6,
4,
7,
5,
5,
7,
6,
9,
3,
9,
],
[
7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1,
9, 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1,
5, 1, 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8,
8, 4, 8, 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6,
3, 5, 1, 4, 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6,
4, 7, 7, 3, 8, 5, 1, 9, 1, 2, 8, 6, 8
]])
7,
8,
1,
3,
1,
7,
6,
3,
5,
3,
8,
3,
1,
9,
7,
1,
1,
9,
5,
4,
9,
6,
1,
9,
3,
8,
3,
9,
9,
6,
4,
2,
8,
5,
3,
1,
6,
9,
1,
3,
9,
8,
1,
7,
5,
1,
5,
1,
8,
7,
4,
5,
9,
8,
7,
4,
7,
3,
6,
4,
6,
6,
5,
5,
2,
9,
9,
5,
8,
8,
4,
8,
2,
8,
1,
3,
9,
1,
8,
5,
8,
3,
8,
8,
2,
7,
3,
7,
5,
7,
2,
6,
3,
5,
1,
4,
6,
1,
9,
8,
2,
2,
3,
6,
7,
6,
2,
6,
5,
1,
5,
6,
2,
1,
6,
4,
7,
7,
3,
8,
5,
1,
9,
1,
2,
8,
6,
8,
],
]
)
logits = paddle.to_tensor(
[[
0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748,
0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807,
0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218,
0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679,
0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888,
0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911,
0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448,
0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415,
0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104,
0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064,
0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672,
0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840,
0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613,
0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254,
0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189,
0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407,
0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923,
0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550,
0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222,
0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845,
0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070,
0.98481447, 0.77367836
[
[
0.16274983,
0.61470598,
0.94366980,
0.82005417,
0.50752640,
0.38316748,
0.92648441,
0.24050158,
0.05461595,
0.42218581,
0.36270225,
0.15464807,
0.13614719,
0.67509544,
0.40315166,
0.10671722,
0.24832056,
0.76091218,
0.11598995,
0.10962527,
0.04688513,
0.81536716,
0.72259802,
0.60476679,
0.16701800,
0.84160781,
0.79649884,
0.78021604,
0.75329530,
0.98587888,
0.13421868,
0.16027625,
0.15269397,
0.06228730,
0.73856270,
0.34721911,
0.73683006,
0.78178608,
0.32068327,
0.79906309,
0.44214272,
0.63330448,
0.08016958,
0.63367140,
0.19788943,
0.55346787,
0.11142531,
0.90518415,
0.21236691,
0.81587470,
0.83752930,
0.70979482,
0.35684183,
0.28715104,
0.87162822,
0.17679396,
0.98725849,
0.76129991,
0.04090235,
0.37181064,
0.63317049,
0.24689502,
0.21126501,
0.57617670,
0.74346697,
0.40613672,
0.56907010,
0.68556929,
0.29032683,
0.17866278,
0.35165095,
0.97015840,
0.70785582,
0.54259878,
0.14712237,
0.90483177,
0.02094105,
0.36411613,
0.02495066,
0.88874054,
0.88895452,
0.86216462,
0.58062190,
0.95583254,
0.20553111,
0.29870346,
0.69652933,
0.36861244,
0.85316223,
0.50240189,
0.17566244,
0.61080140,
0.88203174,
0.98675215,
0.24344546,
0.17213407,
0.78160852,
0.25165486,
0.48188508,
0.82812423,
0.10199814,
0.90475923,
0.66907483,
0.71910626,
0.40660757,
0.59460294,
0.70212913,
0.90841550,
0.00329034,
0.11290466,
0.89654654,
0.69114941,
0.29473618,
0.62027222,
0.37333879,
0.98911142,
0.46510187,
0.65914583,
0.73022646,
0.12790845,
0.12817244,
0.43015456,
0.75011456,
0.43562204,
0.48086026,
0.75587070,
0.98481447,
0.77367836,
],
[
0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417,
0.01848883, 0.78326035, 0.99228370, 0.81447607, 0.02627683,
0.51033205, 0.98703283, 0.15247856, 0.77640921, 0.60799915,
0.87518770, 0.76818430, 0.86542630, 0.31795895, 0.04829503,
0.85567141, 0.30271924, 0.67515039, 0.59728831, 0.78710967,
0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547,
0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396,
0.95318258, 0.93987304, 0.61142951, 0.26737869, 0.52285451,
0.03479086, 0.61631846, 0.66777998, 0.15736090, 0.00447258,
0.37035006, 0.15281211, 0.95372260, 0.25963321, 0.61036694,
0.15020694, 0.19171195, 0.55252832, 0.00391038, 0.31052542,
0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293,
0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045,
0.13202040, 0.07133843, 0.75313550, 0.17111187, 0.80716974,
0.00172165, 0.83906764, 0.73240769, 0.85843354, 0.11042888,
0.07912333, 0.33689004, 0.22334915, 0.59059596, 0.52789515,
0.29831955, 0.39515004, 0.55602801, 0.83818001, 0.05865780,
0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544,
0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529,
0.29571742, 0.76361233, 0.72476572, 0.18568406, 0.85430276,
0.02057583, 0.76195669, 0.65507215, 0.69129735, 0.25084621,
0.75223947, 0.06064088, 0.20287007, 0.35887691, 0.75043523,
0.47575447, 0.40021798, 0.44464844, 0.67975360, 0.40443239,
0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721,
0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736,
0.87317258, 0.88229448, 0.79217333
]])
0.12336024,
0.74152875,
0.09191196,
0.99301219,
0.44764417,
0.01848883,
0.78326035,
0.99228370,
0.81447607,
0.02627683,
0.51033205,
0.98703283,
0.15247856,
0.77640921,
0.60799915,
0.87518770,
0.76818430,
0.86542630,
0.31795895,
0.04829503,
0.85567141,
0.30271924,
0.67515039,
0.59728831,
0.78710967,
0.75111693,
0.56837374,
0.49085775,
0.91510201,
0.59545547,
0.99482232,
0.59036905,
0.58267909,
0.28770933,
0.53237396,
0.95318258,
0.93987304,
0.61142951,
0.26737869,
0.52285451,
0.03479086,
0.61631846,
0.66777998,
0.15736090,
0.00447258,
0.37035006,
0.15281211,
0.95372260,
0.25963321,
0.61036694,
0.15020694,
0.19171195,
0.55252832,
0.00391038,
0.31052542,
0.96495175,
0.42586124,
0.05630261,
0.99728668,
0.01856293,
0.83201504,
0.10701843,
0.56434178,
0.38009524,
0.51095045,
0.13202040,
0.07133843,
0.75313550,
0.17111187,
0.80716974,
0.00172165,
0.83906764,
0.73240769,
0.85843354,
0.11042888,
0.07912333,
0.33689004,
0.22334915,
0.59059596,
0.52789515,
0.29831955,
0.39515004,
0.55602801,
0.83818001,
0.05865780,
0.25654668,
0.76624149,
0.35190639,
0.04158346,
0.59157544,
0.30779791,
0.94609004,
0.10759670,
0.65575141,
0.37828529,
0.29571742,
0.76361233,
0.72476572,
0.18568406,
0.85430276,
0.02057583,
0.76195669,
0.65507215,
0.69129735,
0.25084621,
0.75223947,
0.06064088,
0.20287007,
0.35887691,
0.75043523,
0.47575447,
0.40021798,
0.44464844,
0.67975360,
0.40443239,
0.71052992,
0.21782248,
0.50568426,
0.89037591,
0.06661721,
0.28788096,
0.70773387,
0.42428264,
0.80419677,
0.42710736,
0.87317258,
0.88229448,
0.79217333,
],
]
)
# pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
# logits = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
@@ -195,60 +658,270 @@ print("min_len\n", min_len)
print("eos_token_id\n", eos_token_id)
ref_logits = np.array(
[[
-10000000000., -10000000000., 1.88733959, 1.64010835, 1.01505280,
0.76633495, 1.85296881, 0.48100317, 0.10923190, 0.84437162, 0.72540450,
0.30929613, 0.27229437, 1.35019088, 0.80630332, 0.21343444, 0.49664113,
1.52182436, 0.23197991, 0.21925054, 0.09377026, 1.63073432, 1.44519603,
1.20953357, 0.33403599, 1.68321562, 1.59299767, 1.56043208, 1.50659060,
1.97175777, 0.26843736, 0.32055250, 0.30538794, 0.12457460, 1.47712541,
0.69443822, 1.47366011, 1.56357217, 0.64136654, 1.59812617, 0.88428545,
1.26660895, 0.16033916, 1.26734281, 0.39577886, 1.10693574, 0.22285062,
1.81036830, 0.42473382, 1.63174939, 1.67505860, 1.41958964, 0.71368366,
0.57430208, 1.74325645, 0.35358793, 1.97451699, 1.52259982, 0.08180470,
0.74362129, 1.26634097, 0.49379003, 0.42253003, 1.15235341, 1.48693395,
0.81227344, 1.13814020, 1.37113857, 0.58065367, 0.35732555, 0.70330191,
1.94031680, 1.41571164, 1.08519757, 0.29424474, 1.80966353, 0.04188210,
0.72823226, 0.04990132, 1.77748108, 1.77790904, 1.72432923, 1.16124380,
1.91166508, 0.41106221, 0.59740692, 1.39305866, 0.73722488, 1.70632446,
1.00480378, 0.35132489, 1.22160280, 1.76406348, 1.97350430, 0.48689091,
0.34426814, 1.56321704, 0.50330973, 0.96377015, 1.65624845, 0.20399629,
1.80951846, 1.33814967, 1.43821251, 0.81321514, 1.18920588, 1.40425825,
1.81683099, 0.00658068, 0.22580932, 1.79309309, 1.38229883, 0.58947235,
1.24054444, 0.74667758, 1.97822285, 0.93020374, 1.31829166, 1.46045291,
0.25581691, 0.25634488, 0.86030912, 1.50022912, 0.87124407, 0.96172053,
1.51174140, 1.96962893, 1.54735672
[
[
-10000000000.0,
-10000000000.0,
1.88733959,
1.64010835,
1.01505280,
0.76633495,
1.85296881,
0.48100317,
0.10923190,
0.84437162,
0.72540450,
0.30929613,
0.27229437,
1.35019088,
0.80630332,
0.21343444,
0.49664113,
1.52182436,
0.23197991,
0.21925054,
0.09377026,
1.63073432,
1.44519603,
1.20953357,
0.33403599,
1.68321562,
1.59299767,
1.56043208,
1.50659060,
1.97175777,
0.26843736,
0.32055250,
0.30538794,
0.12457460,
1.47712541,
0.69443822,
1.47366011,
1.56357217,
0.64136654,
1.59812617,
0.88428545,
1.26660895,
0.16033916,
1.26734281,
0.39577886,
1.10693574,
0.22285062,
1.81036830,
0.42473382,
1.63174939,
1.67505860,
1.41958964,
0.71368366,
0.57430208,
1.74325645,
0.35358793,
1.97451699,
1.52259982,
0.08180470,
0.74362129,
1.26634097,
0.49379003,
0.42253003,
1.15235341,
1.48693395,
0.81227344,
1.13814020,
1.37113857,
0.58065367,
0.35732555,
0.70330191,
1.94031680,
1.41571164,
1.08519757,
0.29424474,
1.80966353,
0.04188210,
0.72823226,
0.04990132,
1.77748108,
1.77790904,
1.72432923,
1.16124380,
1.91166508,
0.41106221,
0.59740692,
1.39305866,
0.73722488,
1.70632446,
1.00480378,
0.35132489,
1.22160280,
1.76406348,
1.97350430,
0.48689091,
0.34426814,
1.56321704,
0.50330973,
0.96377015,
1.65624845,
0.20399629,
1.80951846,
1.33814967,
1.43821251,
0.81321514,
1.18920588,
1.40425825,
1.81683099,
0.00658068,
0.22580932,
1.79309309,
1.38229883,
0.58947235,
1.24054444,
0.74667758,
1.97822285,
0.93020374,
1.31829166,
1.46045291,
0.25581691,
0.25634488,
0.86030912,
1.50022912,
0.87124407,
0.96172053,
1.51174140,
1.96962893,
1.54735672,
],
[
-10000000000., -10000000000., -40000., 3.97204876, 1.79057670,
0.07395532, 3.13304138, 3.96913481, 3.25790429, -40000., 2.04132819,
3.94813132, 0.60991424, 3.10563684, 2.43199658, 3.50075078,
3.07273722, 3.46170521, 1.27183580, 0.19318011, 3.42268562,
1.21087694, 2.70060158, 2.38915324, 3.14843869, 3.00446773,
2.27349496, 1.96343100, 3.66040802, 2.38182187, 3.97928929,
2.36147618, 2.33071637, 1.15083730, 2.12949586, 3.81273031,
3.75949216, 2.44571805, 1.06951475, 2.09141803, 0.13916343,
2.46527386, 2.67111993, 0.62944359, 0.01789032, 1.48140025,
0.61124843, 3.81489038, 1.03853285, 2.44146776, 0.60082775,
0.76684779, 2.21011329, 0.01564152, 1.24210167, 3.85980701,
1.70344496, 0.22521044, 3.98914671, 0.07425172, 3.32806015,
0.42807373, 2.25736713, 1.52038097, 2.04380178, 0.52808160,
0.28535372, 3.01254201, 0.68444747, 3.22867894, 0.00688660,
3.35627055, 2.92963076, 3.43373418, 0.44171551, 0.31649333,
1.34756017, 0.89339662, 2.36238384, 2.11158061, 1.19327819,
1.58060014, 2.22411203, 3.35272002, 0.23463120, 1.02618670,
3.06496596, 1.40762556, 0.16633384, 2.36630177, 1.23119164,
3.78436017, 0.43038681, 2.62300563, 1.51314116, 1.18286967,
3.05444932, 2.89906287, 0.74273622, 3.41721106, 0.08230332,
3.04782677, 2.62028861, 2.76518941, 1.00338483, 3.00895786,
0.24256352, 0.81148028, 1.43550766, 3.00174093, 1.90301788,
1.60087192, 1.77859378, 2.71901441, 1.61772954, 2.84211969,
0.87128991, 2.02273703, 3.56150365, 0.26646885, 1.15152383,
2.83093548, 1.69713056, 3.21678710, 1.70842946, 3.49269032,
3.52917790, 3.16869330
]],
-10000000000.0,
-10000000000.0,
-40000.0,
3.97204876,
1.79057670,
0.07395532,
3.13304138,
3.96913481,
3.25790429,
-40000.0,
2.04132819,
3.94813132,
0.60991424,
3.10563684,
2.43199658,
3.50075078,
3.07273722,
3.46170521,
1.27183580,
0.19318011,
3.42268562,
1.21087694,
2.70060158,
2.38915324,
3.14843869,
3.00446773,
2.27349496,
1.96343100,
3.66040802,
2.38182187,
3.97928929,
2.36147618,
2.33071637,
1.15083730,
2.12949586,
3.81273031,
3.75949216,
2.44571805,
1.06951475,
2.09141803,
0.13916343,
2.46527386,
2.67111993,
0.62944359,
0.01789032,
1.48140025,
0.61124843,
3.81489038,
1.03853285,
2.44146776,
0.60082775,
0.76684779,
2.21011329,
0.01564152,
1.24210167,
3.85980701,
1.70344496,
0.22521044,
3.98914671,
0.07425172,
3.32806015,
0.42807373,
2.25736713,
1.52038097,
2.04380178,
0.52808160,
0.28535372,
3.01254201,
0.68444747,
3.22867894,
0.00688660,
3.35627055,
2.92963076,
3.43373418,
0.44171551,
0.31649333,
1.34756017,
0.89339662,
2.36238384,
2.11158061,
1.19327819,
1.58060014,
2.22411203,
3.35272002,
0.23463120,
1.02618670,
3.06496596,
1.40762556,
0.16633384,
2.36630177,
1.23119164,
3.78436017,
0.43038681,
2.62300563,
1.51314116,
1.18286967,
3.05444932,
2.89906287,
0.74273622,
3.41721106,
0.08230332,
3.04782677,
2.62028861,
2.76518941,
1.00338483,
3.00895786,
0.24256352,
0.81148028,
1.43550766,
3.00174093,
1.90301788,
1.60087192,
1.77859378,
2.71901441,
1.61772954,
2.84211969,
0.87128991,
2.02273703,
3.56150365,
0.26646885,
1.15152383,
2.83093548,
1.69713056,
3.21678710,
1.70842946,
3.49269032,
3.52917790,
3.16869330,
],
],
"float32",
)
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, 'Check failed.'
assert diff_logits < 1e-6, "Check failed."

View File

@@ -21,19 +21,30 @@ paddle.seed(2023)
pre_ids_all = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64")
input_ids = paddle.to_tensor([[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1]],
"int64")
"int64",
)
input_ids = paddle.to_tensor(
[
[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1],
],
"int64",
)
seq_lens_this_time = paddle.to_tensor([1, 1], "int32")
seq_lens_encoder = paddle.to_tensor([1, 1], "int32")
seq_lens_decoder = paddle.to_tensor([1, 1], "int32")
step_idx = paddle.to_tensor([1, 1], "int64")
stop_flags = paddle.to_tensor([0, 1], "bool")
print("pre_ids_all\n", pre_ids_all)
set_value_by_flags_and_idx(pre_ids_all, input_ids, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, step_idx,
stop_flags)
set_value_by_flags_and_idx(
pre_ids_all,
input_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
stop_flags,
)
print("pre_ids_all\n", pre_ids_all)
print("input_ids\n", input_ids)
print("seq_lens_this_time\n", seq_lens_this_time)
@@ -73,4 +84,4 @@ ref_pre_ids_all = np.array(
)
diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy()))
print("diff_pre_ids_all\n", diff_pre_ids_all)
assert diff_pre_ids_all == 0, 'Check failed.'
assert diff_pre_ids_all == 0, "Check failed."

View File

@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
max_block_num = block_bs * max_seq_len // block_size
free_list_len = int(max_block_num * (1 - block_ratio))
free_list_len = np.full([1], free_list_len, "int32")
free_list = np.arange(max_block_num - 1,
max_block_num - free_list_len - 1,
-1,
dtype="int32")
free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")
encoder_block_lens = np.zeros([max_bs], "int32")
used_list_len = np.zeros([max_bs], "int32")
@@ -53,19 +50,15 @@ encoder_block_id = 0
for i in range(bs):
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
encoder_block_lens[i] = enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size -
1) // block_size - enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
used_list_len[i] = dec_block_num
block_tables[i, :enc_block_num] = np.arange(
encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id += enc_block_num
if dec_block_num > 0:
block_tables[
i, enc_block_num:enc_block_num +
dec_block_num] = free_list[free_list_len[0] - 1 -
dec_block_num:free_list_len[0] - 1]
free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
1] = -1
block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
]
free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list_len[0] -= dec_block_num
assert free_list_len[0] >= 0
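The ceiling-division bookkeeping above is easiest to check on a single request; the numbers below are made up for illustration and are not the values used by this test.

block_size = 64               # assumed block size
ori_seq_len_encoder = 100     # hypothetical prompt length
seq_len_decoder = 130         # hypothetical prompt + generated length

enc_block_num = (ori_seq_len_encoder + block_size - 1) // block_size              # ceil(100 / 64) = 2
dec_block_num = (seq_len_decoder + block_size - 1) // block_size - enc_block_num  # ceil(130 / 64) - 2 = 1
print(enc_block_num, dec_block_num)  # -> 2 1; the extra decoder block is taken from the tail of free_list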
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
# print("step_idx: ", step_idx)
# print("next_tokens: ", next_tokens)
step_paddle(stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder, seq_lens_decoder, block_tables,
encoder_block_lens, is_block_step, step_block_list, step_lens,
recover_block_list, recover_lens, need_block_list, need_block_len,
used_list_len, free_list, free_list_len, input_ids, pre_ids,
step_idx, next_tokens, first_token_ids, block_size,
encoder_decoder_block_num)
step_paddle(
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_lens,
recover_block_list,
recover_lens,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
next_tokens,
first_token_ids,
block_size,
encoder_decoder_block_num,
)
print("-" * 50 + "after step op" + "-" * 50)
print("stop_flags: ", stop_flags)

View File

@@ -30,8 +30,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
False)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, False)
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
@@ -40,44 +39,220 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array(
[
0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, 18, 19, 20,
0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, -1, 37, 38, 39,
-1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, 0, -1, 0, 57, -1, 59,
60, 0, 0, 63
0,
0,
2,
3,
-1,
0,
0,
0,
0,
9,
10,
0,
12,
0,
-1,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
-1,
29,
30,
31,
0,
0,
0,
-1,
-1,
37,
38,
39,
-1,
-1,
0,
0,
0,
0,
46,
-1,
0,
49,
50,
0,
52,
53,
0,
-1,
0,
57,
-1,
59,
60,
0,
0,
63,
],
"int64",
)
ref_next_tokens = np.array(
[
0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, 18, 19, 20,
0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, 0, 37, 38, 39, 0,
0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, 0, 0, 0, 57, 0, 59, 60, 0,
0, 63
0,
0,
2,
3,
0,
0,
0,
0,
0,
9,
10,
0,
12,
0,
0,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
0,
29,
30,
31,
0,
0,
0,
0,
0,
37,
38,
39,
0,
0,
0,
0,
0,
0,
46,
0,
0,
49,
50,
0,
52,
53,
0,
0,
0,
57,
0,
59,
60,
0,
0,
63,
],
"int64",
)
ref_stop_flags = np.array(
[
True, True, True, True, True, True, True, True, True, False, False,
True, False, True, True, False, False, True, False, False, False, True,
False, False, True, False, False, False, True, False, False, False,
True, True, True, True, True, False, False, False, True, True, True,
True, True, True, False, True, True, False, False, True, False, False,
True, True, True, False, True, False, False, True, True, False
True,
True,
True,
True,
True,
True,
True,
True,
True,
False,
False,
True,
False,
True,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
False,
True,
True,
True,
True,
True,
False,
False,
False,
True,
True,
True,
True,
True,
True,
False,
True,
True,
False,
False,
True,
False,
False,
True,
True,
True,
False,
True,
False,
False,
True,
True,
False,
],
"bool",
)
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, 'Check failed.'
assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, 'Check failed.'
diff_stop_flags = np.sum(
np.abs(
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, 'Check failed.'
assert diff_stop_flags == 0, "Check failed."
# test beam_search=True
topk_ids = paddle.arange(0, bs, dtype="int64")
@@ -88,8 +263,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
True)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, True)
print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags)
@@ -98,42 +272,217 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array(
[
0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, 18, 19,
20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, 36, 37, 0,
-1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58,
-1, 60, 61, -1, 63
0,
1,
2,
3,
4,
0,
6,
7,
-1,
9,
10,
0,
-1,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
-1,
-1,
28,
29,
0,
0,
-1,
33,
34,
35,
36,
37,
0,
-1,
0,
41,
-1,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
-1,
60,
61,
-1,
63,
],
"int64",
)
ref_next_tokens = np.array(
[
0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, 18, 19, 20,
0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, 36, 37, 0, 0, 0,
41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 0, 60, 61,
0, 63
0,
1,
2,
3,
4,
0,
6,
7,
0,
9,
10,
0,
0,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
0,
0,
28,
29,
0,
0,
0,
33,
34,
35,
36,
37,
0,
0,
0,
41,
0,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
0,
60,
61,
0,
63,
],
"int64",
)
ref_stop_flags = np.array(
[
False, False, False, False, False, True, False, False, True, False,
False, True, True, False, False, False, True, False, False, False,
False, True, False, False, False, False, True, True, False, False,
True, True, True, False, False, False, False, False, True, True, True,
False, True, True, False, False, False, True, True, False, True, True,
True, False, True, True, True, True, False, True, False, False, True,
False
False,
False,
False,
False,
False,
True,
False,
False,
True,
False,
False,
True,
True,
False,
False,
False,
True,
False,
False,
False,
False,
True,
False,
False,
False,
False,
True,
True,
False,
False,
True,
True,
True,
False,
False,
False,
False,
False,
True,
True,
True,
False,
True,
True,
False,
False,
False,
True,
True,
False,
True,
True,
True,
False,
True,
True,
True,
True,
False,
True,
False,
False,
True,
False,
],
"bool",
)
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, 'Check failed.'
assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, 'Check failed.'
diff_stop_flags = np.sum(
np.abs(
ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, 'Check failed.'
assert diff_stop_flags == 0, "Check failed."

View File

@@ -60,9 +60,17 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens)
print("is_block_step:\n", is_block_step)
update_inputs(stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder,
seq_lens_decoder, input_ids, stop_nums, next_tokens,
is_block_step)
update_inputs(
stop_flags,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
next_tokens,
is_block_step,
)
print("-" * 50)
print("stop_flags:\n", stop_flags)
@@ -75,32 +83,269 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens)
ref_not_need_stop_out = np.array([True])
ref_seq_lens_this_time_out = np.array([
0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1,
0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1
], "int32")
ref_seq_lens_encoder_out = np.array([
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], "int32")
ref_seq_lens_decoder_out = np.array([
0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0,
24, 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0,
46, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], "int32")
input_ids_np[:, 0] = np.array([
6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, 5,
9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, 6, 7,
4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8
], "int64")
ref_seq_lens_this_time_out = np.array(
[
0,
0,
1,
0,
0,
1,
0,
1,
1,
1,
0,
1,
1,
0,
0,
0,
0,
0,
0,
0,
1,
1,
0,
1,
1,
0,
1,
1,
0,
0,
0,
1,
1,
0,
1,
0,
0,
1,
0,
1,
0,
0,
1,
0,
0,
1,
1,
1,
],
"int32",
)
ref_seq_lens_encoder_out = np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
ref_seq_lens_decoder_out = np.array(
[
0,
0,
2,
0,
0,
6,
0,
8,
8,
10,
0,
12,
12,
0,
0,
0,
0,
0,
0,
0,
20,
22,
0,
24,
24,
0,
26,
28,
0,
0,
0,
32,
32,
0,
34,
0,
0,
38,
0,
40,
0,
0,
42,
0,
0,
46,
46,
48,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
input_ids_np[:, 0] = np.array(
[
6,
5,
9,
8,
6,
2,
8,
1,
3,
1,
3,
6,
9,
8,
1,
9,
1,
8,
8,
6,
7,
6,
5,
3,
5,
9,
3,
6,
3,
9,
8,
8,
8,
8,
4,
8,
7,
4,
2,
3,
5,
8,
4,
2,
5,
6,
8,
9,
6,
7,
4,
2,
4,
6,
2,
3,
4,
9,
7,
2,
1,
8,
7,
8,
],
"int64",
)
assert not_need_stop.numpy(
) == ref_not_need_stop_out, 'Check not_need_stop failed.'
assert np.all(seq_lens_this_time.numpy() ==
ref_seq_lens_this_time_out), 'Check seq_lens_this_time failed.'
assert np.all(seq_lens_encoder.numpy() ==
ref_seq_lens_encoder_out), 'Check seq_lens_encoder failed.'
assert np.all(seq_lens_decoder.numpy() ==
ref_seq_lens_decoder_out), 'Check seq_lens_decoder failed.'
assert np.all(input_ids.numpy() == input_ids_np), 'Check input_ids failed.'
assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."

View File

@@ -29,16 +29,15 @@ def np_quant_weight_int4(weight_np):
weight = np.transpose(weight_np, [1, 0]) # n,k
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1) # k => k,1
quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7) # n,k
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (
quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF) # pack int4, [n,k//2]
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
return quanted_weight, weight_scales.astype(np.float32)
def np_quant_weight(weight_np, algo='weight_only_int8'):
def np_quant_weight(weight_np, algo="weight_only_int8"):
assert weight_np.dtype == np.float32
if algo == 'weight_only_int4':
if algo == "weight_only_int4":
return np_quant_weight_int4(weight_np)
weight = np.transpose(weight_np, [1, 0])
@@ -56,7 +55,7 @@ def int8_to_bin_np(value):
def int8_to_bin(value):
if not -128 <= value <= 127:
raise ValueError("int8 值必须在 -128 到 127 之间")
return format(value & 0xFF, '08b') # '08b' 表示 8 位二进制,高位补零
return format(value & 0xFF, "08b") # '08b' 表示 8 位二进制,高位补零
# 1) preparation
@@ -70,7 +69,7 @@ w_np = (np.random.random((k, n)).astype(np.float32) - 0.5) * 10
qw_np, wscale_np = np_quant_weight(w_np, algo)
# 3) xpu calculation
dtype = 'float32'
dtype = "float32"
x_pd = paddle.to_tensor(w_np, dtype=dtype)
qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1)
qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
@@ -83,12 +82,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
# comparison
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
print(
f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}"
)
print(
f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}"
)
sum_diff = np.sum(
np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
print(f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}")
sum_diff = np.sum(np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"sum_diff: {sum_diff}")

View File

@@ -72,7 +72,8 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
### Multi-machine Disaggregated Deployment
#### Prerequisite: Redis
- Installation via `conda`
* Installation via `conda`
```bash
# Install
conda install redis
@@ -80,7 +81,8 @@ conda install redis
nohup redis-server > redis.log 2>&1 &
```
- Installation via `apt`
* Installation via `apt`
```bash
# Install
sudo apt install redis-server -y
@@ -88,7 +90,8 @@ sudo apt install redis-server -y
sudo systemctl start redis-server
```
- Installation via `yum`
* Installation via `yum`
```bash
# Install
sudo yum install redis -y

View File

@@ -38,6 +38,7 @@ conda install redis
# Launch
nohup redis-server > redis.log 2>&1 &
```
### apt installation (Debian/Ubuntu)
```bash
@@ -57,6 +58,7 @@ sudo systemctl start redis
```
## Launching FastDeploy
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--port 8801 \
@@ -72,6 +74,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-min-load_score 3 \
--scheduler-load-shards-num 1
```
[Scheduler Launching Parameter](../online_serving/scheduler.md)
### Deployment notes:

View File

@@ -20,6 +20,7 @@ For reasoning models, the length of the reasoning content can be controlled via
### Quick Start
When launching the model service, specify the parser name using the `--reasoning-parser` argument.
This parser will process the model's output and extract the `reasoning_content` field.
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \
@@ -29,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \
--reasoning-parser ernie-45-vl
```
Next, send a request to the model; the reasoning content should be returned in the response.
```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \
@@ -43,10 +46,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
"metadata": {"enable_thinking": true}
}'
```
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
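For non-streaming requests, the same two fields can be read with the OpenAI Python client. The snippet below is a minimal sketch: it assumes the service launched above is reachable on port 8192, passes the `metadata` field from the curl example via `extra_body`, and uses a placeholder model name.
```python
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8192/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="placeholder",  # placeholder; the served model is determined at launch time
    messages=[{"role": "user", "content": "Which is larger, 9.9 or 9.11?"}],
    extra_body={"metadata": {"enable_thinking": True}},
)

message = response.choices[0].message
print("reasoning_content:", getattr(message, "reasoning_content", None))  # reasoning steps
print("content:", message.content)  # final conclusion
```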
### Streaming chat completions
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks`
```python
from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server.

View File

@@ -132,4 +132,3 @@ Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```

View File

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
```
## 2. Start service
```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization=0.8
```
#### Send requests
### Send requests
Send requests using either curl or Python

View File

@@ -19,13 +19,15 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
## Container Preparation
1. Start Container
```bash
docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
docker exec -it paddle_infer bash
```
/home/paddle contains the model files, *.whl packages, and scripts.
2. Install packages
1. Install packages
```bash
pip3 install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
@@ -38,6 +40,7 @@ pip3 install fastdeploy_iluvatar_gpu -i https://www.paddlepaddle.org.cn/packages
The scripts are listed below:
`run_demo.sh`:
```bash
#!/bin/bash
export PADDLE_XCCL_BACKEND=iluvatar_gpu
@@ -78,7 +81,9 @@ for output in outputs:
```bash
./run_demo.sh
```
The following logs will be printed: Loading the model took approximately 74 seconds, and running the demo took approximately 240 seconds.
```
/usr/local/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
warnings.warn(warning_message)

View File

@@ -36,6 +36,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
]
}'
```
Here's an example curl command demonstrating how to include the logprobs parameter in a user request:
```bash
@@ -49,6 +50,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
```
Here is an example of sending a user request using a Python script:
```python
import openai
host = "0.0.0.0"

View File

@@ -45,7 +45,6 @@ When using FastDeploy to deploy models (including offline inference and service
| ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
| ```enable_logprob``` | `bool` | Whether to return the log probabilities of the output tokens. If true, the log probabilities of each output token are returned in the message content. If logprob is not used, this parameter can be omitted at startup |
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
During FastDeploy inference, GPU memory is occupied by ```model weights```, ```preallocated KVCache blocks``` and ```model computation intermediate activation values```. The preallocated KVCache blocks are determined by ```num_gpu_blocks_override```, with ```block_size``` (default: 64) as its unit, meaning one block can store KVCache for 64 Tokens.
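As a quick illustration of this relationship, a minimal sketch with made-up numbers (not tied to any particular model):
```python
# Illustrative values only: how many tokens the preallocated KVCache can hold.
block_size = 64                 # tokens per KVCache block (default)
num_gpu_blocks_override = 2048  # number of preallocated KVCache blocks (made-up value)

max_cached_tokens = num_gpu_blocks_override * block_size
print(max_cached_tokens)  # 131072 tokens of KVCache shared across all running requests
```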
@@ -80,18 +79,18 @@ Currently, only user configuration of the following parameters is supported
CudaGraph can be enabled by setting `--use-cudagraph` or `--graph-optimization-config '{"use_cudagraph":true}'`. Setting it through both methods at the same time may cause conflicts.
The `graph_opt_level` parameter within `--graph-optimization-config` is used to configure the graph optimization level, with the following available options:
- `0`: Use the dynamic compute graph (the default)
- `1`: Use the static compute graph; during the initialization phase, the Paddle API is used to convert the dynamic graph into a static graph
- `2`: On top of the static compute graph, use Paddle's compiler (CINN, Compiler Infrastructure for Neural Networks) to compile and optimize
In general, static graphs have lower kernel launch overhead than dynamic graphs, so using static graphs is recommended.
For adapted models, FastDeploy's CudaGraph * * can support both dynamic and static graphs * * simultaneously.
For adapted models, FastDeploy's CudaGraph *can support both dynamic and static graphs* simultaneously.
When CudaGraph is enabled in the default configuration, the list of batch sizes that CudaGraph needs to capture is set automatically based on the `max_num_seqs` parameter. The logic for generating this list is as follows:
1. Generate a candidate list of batch sizes in the range [1, 1024].
```
# Batch Size [1, 2, 4, 8, 16, ... 120, 128]
candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
@@ -100,24 +99,25 @@ When CudaGraph is enabled in the default configuration, a list of Batch Sizes th
# Batch Size (256, 288, ... 992, 1024]
candidate_capture_sizes += [32 * i for i in range(17, 33)]
```
2. Crop the candidate list based on the user-set `max_num_seqs` to obtain the CudaGraph capture list, with a range of [1, `max_num_seqs`], as sketched below.
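A minimal Python sketch of this cropping step (the helper name `crop_capture_sizes` is only a stand-in for the internal logic, not an actual FastDeploy API):
```python
def crop_capture_sizes(candidate_capture_sizes, max_num_seqs):
    """Keep only candidate batch sizes that do not exceed the configured max_num_seqs."""
    return [bs for bs in candidate_capture_sizes if bs <= max_num_seqs]

# Illustrative: with max_num_seqs=64, the capture list becomes [1, 2, 4, 8, 16, ..., 64].
print(crop_capture_sizes([1, 2, 4] + [8 * i for i in range(1, 17)], 64))
```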
Users can also customize the list of batch sizes that CudaGraph needs to capture through the `cudagraph_capture_sizes` parameter in `--graph-optimization-config`:
```
--graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}'
```
### CudaGraph related parameters
Using CudaGraph incurs some additional memory overhead, divided into two categories in FastDeploy:
* Additional input Buffer overhead
* CudaGraph uses dedicated memory pool, thus holding some intermediate activation memory isolated from main framework
- Additional input buffer overhead
- CudaGraph uses a dedicated memory pool, and therefore holds some intermediate activation memory that is isolated from the main framework
FastDeploy's initialization sequence first uses the `gpu_memory_utilization` parameter to calculate the memory available for `KVCache`; only after `KVCache` is initialized does it use the remaining memory to initialize CudaGraph. Since CudaGraph is not enabled by default yet, using the default startup parameters may run into `Out of memory` errors. The following solutions can be tried:
* Lower `gpu_memory_utilization` value, reserve more memory for CudaGraph.
* Lower `max_num_seqs` to decrease the maximum concurrency.
* Customize the batch size list that CudaGraph needs to capture through `graph_optimization_config`, and reduce the number of captured graphs by using `cudagraph_capture_sizes`
- Lower the `gpu_memory_utilization` value to reserve more memory for CudaGraph.
- Lower `max_num_seqs` to decrease the maximum concurrency.
- Customize the list of batch sizes that CudaGraph needs to capture via `cudagraph_capture_sizes` in `graph_optimization_config`, reducing the number of captured graphs.
- Before use, make sure the loaded model is properly decorated with ```@support_graph_optimization```.
@@ -148,5 +148,6 @@ FastDeploy initialization sequence first uses `gpu_memory_utilization` parameter
class Ernie45TModel(nn.Layer): # Note decorator is added to nn.Layer subclass
...
```
- When ```use_cudagraph``` is enabled, only single-GPU inference is currently supported, i.e. ```tensor_parallel_size``` must be set to 1.
- When ```use_cudagraph``` is enabled, ```enable_prefix_caching``` and ```enable_chunked_prefill``` cannot be enabled.

View File

@@ -25,13 +25,10 @@
多实例情况下每收到一条请求需要根据不同的策略将请求分配到不同的Prefill实例和Decode实例。通过角色分离prefill 节点负责接收并处理请求decode节点完成后续生成可以更细粒度地控制资源分配、提高吞吐量与 GPU 利用率。
## 使用说明
### 单机分离式部署
#### 在线推理服务
使用如下命令进行服务部署
@@ -75,9 +72,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
### 多机分离式部署
#### 前置依赖 Redis
- 使用`conda`安装
* 使用`conda`安装
```bash
# 安装
conda install redis
@@ -85,7 +82,8 @@ conda install redis
nohup redis-server > redis.log 2>&1 &
```
- 使用`apt`安装
* 使用`apt`安装
```bash
# 安装
sudo apt install redis-server -y
@@ -93,7 +91,8 @@ sudo apt install redis-server -y
sudo systemctl start redis-server
```
- 使用`yum`安装
* 使用`yum`安装
```bash
# 安装
sudo yum install redis -y

View File

@@ -23,6 +23,7 @@
### 前置依赖 Redis
- 使用`conda`安装
```bash
# 安装
conda install redis
@@ -31,6 +32,7 @@ nohup redis-server > redis.log 2>&1 &
```
- 使用`apt`安装
```bash
# 安装
sudo apt install redis-server -y
@@ -39,6 +41,7 @@ sudo systemctl start redis-server
```
- 使用`yum`安装
```bash
# 安装
sudo yum install redis -y
@@ -47,6 +50,7 @@ sudo systemctl start redis
```
### 启动FastDeploy
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--port 8801 \
@@ -62,6 +66,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-min-load_score 3 \
--scheduler-load-shards-num 1
```
[启动参数说明](../online_serving/scheduler.md)
可以将上述启动命令在多个机器执行,启动多个推理实例(如果是在一个机器中启动多个推理实例,注意端口不要冲突)。

View File

@@ -8,7 +8,6 @@ Prefix Caching前缀缓存是一种优化生成式模型推理效率的技
增量计算:对于后续请求,只需计算新增部分(如用户追加的输入)并复用缓存的中间结果,显著减少计算量。
## 服务化部署开启 Prefix Caching
启动服务增加以下参数 `enable-prefix-caching`默认只开启一级缓存GPU 缓存)。

View File

@@ -17,10 +17,10 @@
同时在思考模型中,支持通过```reasoning_max_tokens```控制思考内容的长度,在请求中添加```metadata={"reasoning_max_tokens": 1024}```即可。
### 快速使用
## 快速使用
在启动模型服务时, 通过`--reasoning-parser`参数指定解析器名称.
该解析器会解析思考模型的输出, 提取`reasoning_content`字段.
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \
@@ -30,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \
--reasoning-parser ernie-45-vl
```
接下来, 向模型发送 `chat completion` 请求
```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \
@@ -45,10 +47,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
}'
```
字段`reasoning_content`包含得出最终结论的思考步骤,而`content`字段包含最终结论。
### 流式会话
在流式会话中, `reasoning_content`字段会可以在`chat completion response chunks`中的 `delta` 中获取
```python
from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server.
@@ -73,4 +77,3 @@ for chunk in chat_response:
print("\n")
```

View File

@@ -50,6 +50,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}'
```
### PD 分离式部署1P1D
> 在8×H100上部署1P1DP、D节点 分别使用 4×H100量化方式选择 WINT4
> 与常规 PD 分离部署一致,仅需替换配置文件并新增 speculative_config
@@ -57,6 +58,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
- P 节点Prefill
> 配置文件: `benchmarks/yaml/eb45t-32k-wint4-mtp-tp4-prefill.yaml`
```
export FD_LOG_DIR="log_prefill"
rm -rf ${FD_LOG_DIR}
@@ -80,9 +82,11 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-password "scheduler_mtp" \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": ""${path_to_mtp_model}"}' &
```
- D 节点Decode
> 配置文件: `benchmarks/yaml/eb45t-32k-wint4-mtp-tp4-decode.yaml`
```
export FD_LOG_DIR="log_prefill"
rm -rf ${FD_LOG_DIR}
@@ -111,6 +115,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
该算法通过 n-gram 窗口从 prompt 和已生成的 Token 中进行匹配生成草稿 Token适合输入和输出有很大 overlap 的场景,如代码续写、文档查询等。
> 使用 4×H100量化方式选择 WINT4
> 配置文件benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml
```
python -m fastdeploy.entrypoints.openai.api_server \
--model ${path_to_main_model} \

View File

@@ -131,4 +131,3 @@ python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parall
```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```

View File

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
```
## 2. 启动服务
```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization=0.8
```
#### 请求服务
### 请求服务
您可以基于 OpenAI 协议,通过 curl 和 python 两种方式请求服务。

View File

@@ -18,13 +18,15 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
## 准备容器
1. 启动容器
```bash
docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
docker exec -it paddle_infer bash
```
/home/paddle 为模型文件、whl包、脚本所在目录
2. 安装whl包
1. 安装whl包
```bash
pip3 install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
@@ -37,6 +39,7 @@ pip3 install fastdeploy_iluvatar_gpu -i https://www.paddlepaddle.org.cn/packages
脚本内容如下
`run_demo.sh`:
```bash
#!/bin/bash
export PADDLE_XCCL_BACKEND=iluvatar_gpu
@@ -48,7 +51,6 @@ python3 run_demo.py
run_demo.py
```python
from fastdeploy import LLM, SamplingParams
@@ -75,10 +77,13 @@ for output in outputs:
## 运行demo
执行
```bash
./run_demo.sh
```
会有如下 log 打印load 模型耗时约74sdemo 运行约240s。
```
/usr/local/lib/python3.10/site-packages/paddle/utils/cpp_extension/extension_utils.py:715: UserWarning: No ccache found. Please be aware that recompiling all source files may be required. You can download and install ccache from: https://github.com/ccache/ccache/blob/master/doc/INSTALL.md
warnings.warn(warning_message)

View File

@@ -21,6 +21,7 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12
## 2. 预编译Pip安装
首先安装 paddlepaddle-gpu详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
``` shell
python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
```
@@ -28,6 +29,7 @@ python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn
再安装 fastdeploy**注意不要通过pypi源安装**,需要通过如下方式安装
如你的 GPU 是 SM80/90 架构(A100/H100等),按如下方式安装
```
# 安装稳定版本fastdeploy
python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
@@ -37,6 +39,7 @@ python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages
```
如你的 GPU 是 SM86/89 架构(4090/L20/L40等),按如下方式安装
```
# 安装稳定版本fastdeploy
python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
@@ -59,11 +62,13 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu .
## 4. Wheel包源码编译
首先安装 paddlepaddle-gpu详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/)
``` shell
python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
```
接着克隆源代码,编译安装
``` shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
@@ -74,11 +79,13 @@ cd FastDeploy
# 第4个参数: 编译的GPU架构
bash build.sh 1 python false [80,90]
```
编译后的产物在```FastDeploy/dist```目录下。
## 环境检查
在安装 FastDeploy 后,通过如下 Python 代码检查环境的可用性
``` python
import paddle
from paddle.jit.marker import unified
@@ -87,4 +94,5 @@ paddle.utils.run_check()
# 检查FastDeploy自定义算子编译成功与否
from fastdeploy.model_executor.ops.gpu import beam_search_softmax
```
如上代码执行成功,则认为环境可用。

View File

@@ -15,6 +15,7 @@
## 1. 启动服务
安装FastDeploy后在终端执行如下命令启动服务其中启动命令配置方式参考[参数说明](../parameters.md)
```shell
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-0.3B-Paddle \
@@ -24,6 +25,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--max-model-len 32768 \
--max-num-seqs 32
```
>💡 注意:在 ```--model``` 指定的路径中,若当前目录下不存在该路径对应的子目录,则会尝试根据指定的模型名称(如 ```baidu/ERNIE-4.5-0.3B-Paddle```查询AIStudio是否存在预置模型若存在则自动启动下载。默认的下载路径为```~/xx```。关于模型自动下载的说明和配置参阅[模型下载](../supported_models.md)。
```--max-model-len``` 表示当前部署的服务所支持的最长Token数量。
```--max-num-seqs``` 表示当前部署的服务所支持的最大并发处理数量。
@@ -36,6 +38,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
## 2. 用户发起服务请求
执行启动服务指令后,当终端打印如下信息,说明服务已经启动成功。
```
api_server.py[line:91] Launching metrics service at http://0.0.0.0:8181/metrics
api_server.py[line:94] Launching chat completion service at http://0.0.0.0:8180/v1/chat/completions
@@ -47,11 +50,13 @@ INFO: Uvicorn running on http://0.0.0.0:8180 (Press CTRL+C to quit)
```
FastDeploy提供服务探活接口用以判断服务的启动状态执行如下命令返回 ```HTTP/1.1 200 OK``` 即表示服务启动成功。
```shell
curl -i http://0.0.0.0:8180/health
```
通过如下命令发起服务请求
```shell
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
-H "Content-Type: application/json" \

View File

@@ -24,6 +24,7 @@
## 文档说明
本项目文档基于mkdocs支持编译可视化查看参考如下命令进行编译预览
```
pip install requirements.txt
@@ -32,4 +33,5 @@ mkdocs build
mkdocs serve
```
根据提示打开相应地址即可。

View File

@@ -19,7 +19,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
--enable-logprob
```
服务部署时的命令行更多使用方式参考[参数说明](../parameters.md)。
## 发送用户请求
@@ -51,6 +50,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
```
使用 Python 脚本发送用户请求示例如下:
```python
import openai
host = "0.0.0.0"
@@ -103,6 +103,7 @@ FastDeploy 增加的返回字段如下:
- `reasoning_content`: 思考链的返回结果
返回参数总览
```python
ChatCompletionStreamResponse:
id: str

View File

@@ -18,7 +18,6 @@ FastDeploy 目前支持两种调度器: **本地调度器** 和 **全局调度
通过角色分离prefill 节点负责接收并处理请求decode节点完成后续生成可以更细粒度地控制资源分配、提高吞吐量与 GPU 利用率。
## 配置参数
| 字段名 | 字段类型 | 是否必填 | 默认值 | 生效范围 | 说明 |
| ------------------------------------ | -------- | -------- | --------- |------------------------|-----------------------------------|

View File

@@ -2,7 +2,6 @@
在使用FastDeploy部署模型包括离线推理、服务化部署涉及如下参数配置其实需要注意在使用离线推理时各参数配置即为如下参数名而在使用命令行启动服务时相应参数中的分隔符需要从```_```修改为```-```,如```max_model_len```在命令行中则为```--max-model-len```。
| 参数名 | 类型 | 说明 |
|:-----------------------------------|:----------| :----- |
| ```port``` | `int` | 仅服务化部署需配置服务HTTP请求端口号默认8000 |
@@ -44,7 +43,6 @@
| ```enable_expert_parallel``` | `bool` | 是否启用专家并行 |
| ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob则在启动时可以省略此参数。 |
## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?
FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache块```和```模型计算中间激活值```占用。其中预分配KVCache块由```num_gpu_blocks_override```决定,其单位为```block_size```(默认64即一个块可以存储64个Token的KVCache。
@@ -88,6 +86,7 @@ FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache
在默认配置下开启 CudaGraph 时,会根据 `max_num_seqs` 参数自动设置 CudaGraph 需要捕获的 Batch Size 列表,需要捕获的 Batch Size 的列表自动生成逻辑如下:
1. 生成一个范围为 [1,1024] Batch Size 的候选列表
```
# Batch Size [1, 2, 4, 8, 16, ... 120, 128]
candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
@@ -96,24 +95,24 @@ FastDeploy在推理过程中显存被```模型权重```、```预分配KVCache
# Batch Size (256, 288, ... 992, 1024]
candidate_capture_sizes += [32 * i for i in range(17, 33)]
```
2. 根据用户设置的 `max_num_seqs` 裁剪候选列表,得到范围为 [1, `max_num_seqs`] 的 CudaGraph 捕获列表。
用户也可以通过 `--graph-optimization-config` 中的 `cudagraph_capture_sizes` 参数自定义需要被 CudaGraph 捕获的 Batch Size 列表:
```
--graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}'
```
### CudaGraph相关参数说明
使用 CudaGraph 会产生一些额外的显存开销在FastDeploy中分为下面两类
* 额外的输入 Buffer 开销
* CudaGraph 使用了专用的显存池,因此会持有一部分与主框架隔离的中间激活显存
- 额外的输入 Buffer 开销
- CudaGraph 使用了专用的显存池,因此会持有一部分与主框架隔离的中间激活显存
FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算 `KVCache` 可用的显存,初始化完 `KVCache` 之后才会使用剩余显存初始化 CudaGraph。由于 CudaGraph 目前还不是默认开启的,因此使用默认启动参数可能会遇到 `Out Of Memory` 错误,可以尝试使用下面三种方式解决:
* 调低 `gpu_memory_utilization` 的值多预留一些显存给CudaGraph使用。
* 调低 `max_num_seqs` 的值,降低最大并发数。
* 通过 `graph_optimization_config` 自定义需要 CudaGraph 捕获的 Batch Size 列表 `cudagraph_capture_sizes`,减少捕获的图的数量
- 调低 `gpu_memory_utilization` 的值多预留一些显存给CudaGraph使用。
- 调低 `max_num_seqs` 的值,降低最大并发数。
- 通过 `graph_optimization_config` 自定义需要 CudaGraph 捕获的 Batch Size 列表 `cudagraph_capture_sizes`,减少捕获的图的数量
使用CudaGraph之前需要确保加载的模型被装饰器 ```@support_graph_optimization```正确修饰。
@@ -144,5 +143,6 @@ FastDeploy 的初始化顺序为先使用 `gpu_memory_utilization` 参数计算
class Ernie45TModel(nn.Layer): # 注意 decorator 加在 nn.Layer 的子类上
...
```
- 当开启 ```use_cudagraph``` 时,暂时只支持单卡推理,即 ```tensor_parallel_size``` 设为1。
- 当开启 ```use_cudagraph``` 时,暂不支持开启 ```enable_prefix_caching``` 或 ```enable_chunked_prefill``` 。

View File

@@ -44,4 +44,3 @@ FastDeploy 按以下格式命名各种量化精度:
- **WNF4A8C8**NF4指4bit norm-float数值类型
- **Wfp8Afp8**权重和激活均为FP8精度
- **W4Afp8**权重为INT4, 激活为FP8

View File

@@ -52,6 +52,3 @@ python -m fastdeploy.entrypoints.openai.api_server \
- 通过设置 `--quantization``block_wise_fp8` 选择在线 Block-wise FP8 量化。
- 部署 ERNIE-4.5-300B-A47B-Paddle Block-wise FP8 最少需要 80G * 8卡。
- 更多部署教程请参考[get_started](../get_started/ernie-4.5.md)

View File

@@ -48,7 +48,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
- 更多部署教程请参考[get_started](../get_started/ernie-4.5.md)
- 更多模型说明请参考[支持模型列表](../supported_models.md)。
## WINT2效果
在ERNIE-4.5-300B-A47B模型上WINT2与WINT4效果对比

View File

@@ -22,4 +22,3 @@
- ```splitwise```: 分离式部署相关模块
- ```scripts```/```tools```FastDeploy 用于执行功能的辅助脚本,比如编译,单测执行,代码风格纠正等
- ```test```:项目单测验证使用到的代码

View File

@@ -19,14 +19,12 @@ FastDeploy 在部署过程中,会产生如下日志文件,各日志含义说
## 在线推理客户端日志
* `api_server.log` : 记录启动参数,及接收到的请求信息
## 调度器日志
* `scheduler.log` : 记录调度器的信息包含当前结点的信息,每条请求分配的信息
## 投机解码日志
* `speculate.log` : 投机解码相关信息
## Prefix Caching 相关日志
* `cache_queue_manager.log` : 记录启动参数,及接收到的请求信息

View File

@@ -22,14 +22,14 @@ import sys
os.environ["GLOG_minloglevel"] = "2"
# suppress log from aistudio
os.environ["AISTUDIO_LOG"] = "critical"
from fastdeploy.utils import version
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
__all__ = ['LLM', 'SamplingParams']
__all__ = ["LLM", "SamplingParams"]
try:
import use_triton_in_paddle
use_triton_in_paddle.make_triton_compatible_with_paddle()
except ImportError:
pass
@@ -38,13 +38,21 @@ except ImportError:
def _patch_fastsafetensors():
try:
file_path = subprocess.check_output([
sys.executable, "-c", "import fastsafetensors, os; \
file_path = (
subprocess.check_output(
[
sys.executable,
"-c",
"import fastsafetensors, os; \
print(os.path.join(os.path.dirname(fastsafetensors.__file__), \
'frameworks', '_paddle.py'))"
]).decode().strip()
'frameworks', '_paddle.py'))",
]
)
.decode()
.strip()
)
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
content = f.read()
if "DType.U16: DType.BF16," in content and "DType.U8: paddle.uint8," in content:
return
@@ -56,21 +64,20 @@ def _patch_fastsafetensors():
inside_block = False
for line in lines:
new_lines.append(line)
if 'need_workaround_dtypes: Dict[DType, DType] = {' in line:
if "need_workaround_dtypes: Dict[DType, DType] = {" in line:
inside_block = True
elif inside_block and '}' in line:
new_lines.insert(-1, ' DType.U16: DType.BF16,')
elif inside_block and "}" in line:
new_lines.insert(-1, " DType.U16: DType.BF16,")
inside_block = False
modified = True
content = "\n".join(new_lines)
if "DType.I8: paddle.uint8," in content:
content = content.replace("DType.I8: paddle.uint8,",
"DType.U8: paddle.uint8,")
content = content.replace("DType.I8: paddle.uint8,", "DType.U8: paddle.uint8,")
modified = True
if modified:
with open(file_path, 'w') as f:
with open(file_path, "w") as f:
f.write(content + "\n")
except Exception as e:

View File

@@ -109,13 +109,12 @@ class BlockNode:
parent_node_id = None
return (
f"node_id {self.node_id}: depth {self.depth} hash_value {self.hash_value}"
+
f" shared_count {self.shared_count} is_gpu_leaf_node {self.is_gpu_leaf_node}"
+
f" is_cpu_leaf_node {self.is_cpu_leaf_node} block_id {self.block_id} "
+ f"has_in_gpu {self.has_in_gpu} " +
f"cache_status {self.cache_status} parent {parent_node_id} with children number "
+ f"{len(self.children)} req_id_set {self.req_id_set}")
+ f" shared_count {self.shared_count} is_gpu_leaf_node {self.is_gpu_leaf_node}"
+ f" is_cpu_leaf_node {self.is_cpu_leaf_node} block_id {self.block_id} "
+ f"has_in_gpu {self.has_in_gpu} "
+ f"cache_status {self.cache_status} parent {parent_node_id} with children number "
+ f"{len(self.children)} req_id_set {self.req_id_set}"
)
@property
def has_in_gpu(self):
@@ -141,8 +140,7 @@ class BlockNode:
"""
check if the node is a leaf node in CPU
"""
if (self.cache_status == CacheStatus.CPU) and (len(self.children)
== 0):
if (self.cache_status == CacheStatus.CPU) and (len(self.children) == 0):
return True
return False

View File

@@ -21,20 +21,20 @@ import time
import numpy as np
import paddle
from fastdeploy.cache_manager.transfer_factory import (IPCCommManager,
RDMACommManager)
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
class CacheMessager(object):
class CacheMessager:
"""
CacheMessager is used to send the cache data between the engine worker and the cache server.
"""
def __init__(self,
def __init__(
self,
splitwise_role,
transfer_protocol,
pod_ip,
@@ -45,7 +45,8 @@ class CacheMessager(object):
nranks,
num_layers,
gpu_id=0,
rdma_port=None):
rdma_port=None,
):
"""
Initialize the CacheMessager object.
@@ -64,8 +65,10 @@ class CacheMessager(object):
None
"""
assert splitwise_role in ["prefill", "decode"], \
"splitwise_role must be prefill or decode"
assert splitwise_role in [
"prefill",
"decode",
], "splitwise_role must be prefill or decode"
self.splitwise_role = splitwise_role
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
@@ -76,11 +79,11 @@ class CacheMessager(object):
is_server=False,
num_client=self.nranks,
client_id=self.rank,
local_data_parallel_id=local_data_parallel_id)
local_data_parallel_id=local_data_parallel_id,
)
transfer_protocol = transfer_protocol.split(",")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}"
f"rank: {rank}")
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers
@@ -90,10 +93,8 @@ class CacheMessager(object):
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
key_cache = self.gpu_cache_kvs[
f'key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
val_cache = self.gpu_cache_kvs[
f'value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}']
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
@@ -109,7 +110,8 @@ class CacheMessager(object):
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}")
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
)
self.block_bytes = block_bytes
# 3. initialize the messager
@@ -122,24 +124,26 @@ class CacheMessager(object):
cache_v,
)
local_device_id = int(str(cache_k[0].place)[-2])
logger.info(
f"done create ipc_comm with local_device_id:{local_device_id}, "
)
logger.info(f"done create ipc_comm with local_device_id:{local_device_id}, ")
elif protocol == "rdma":
logger.info(
f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}"
)
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
self.messager[protocol] = RDMACommManager(
splitwise_role, rank, gpu_id, cache_k_ptr_list,
cache_v_ptr_list, max_block_num, block_bytes, rdma_port)
splitwise_role,
rank,
gpu_id,
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
rdma_port,
)
self.gpu_id = gpu_id
self.cache_info = dict()
layerwise_send_cache_thread = threading.Thread(
target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread.daemon = True
layerwise_send_cache_thread.start()
@@ -159,26 +163,30 @@ class CacheMessager(object):
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True)
create=True,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.rank}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True)
create=True,
)
except:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.rank}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False)
create=False,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.rank}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False)
create=False,
)
step_shm_value.value[0] = -1
layer_shm_value.value[0] = -1
@@ -193,21 +201,19 @@ class CacheMessager(object):
if cache_info:
logger.debug(f"cache info {cache_info}")
for info in cache_info:
if info['request_id'] in self.cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
current_info = self.cache_info[info["request_id"]]
if "dest_block_ids" in current_info and "src_block_ids" in current_info:
current_src_blocks = current_info[
"src_block_ids"][-len(current_info["dest_block_ids"]):]
current_info[
"src_block_ids"] = current_src_blocks
current_src_blocks = current_info["src_block_ids"][
-len(current_info["dest_block_ids"]) :
]
current_info["src_block_ids"] = current_src_blocks
current_info["current_layer_ids"] = 0
current_info["status"] = "init"
logger.info(
f"start cache_infos: {current_info}")
logger.info(f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min(
self.last_step_idx, current_info['current_id'])
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
else:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
@@ -223,64 +229,53 @@ class CacheMessager(object):
if not self.cache_info:
time.sleep(0.001)
continue
logger.debug(
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}"
)
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
for req_id, item in list(self.cache_info.items()):
if "status" not in item:
continue
if "layer_idx" not in item:
item["layer_idx"] = 0
if item['status'] == 'error':
if item["status"] == "error":
del self.cache_info[req_id]
continue
if item['current_id'] > prefilled_step_idx:
if item["current_id"] > prefilled_step_idx:
continue
current_transfer_protocol = item["transfer_protocol"]
if item["transfer_protocol"] == "rdma":
target_ip = item['ip']
target_id = int(item['rdma_ports'][self.rank])
status = self.messager[
current_transfer_protocol].connect(
target_ip, target_id)
target_ip = item["ip"]
target_id = int(item["rdma_ports"][self.rank])
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(
f"connect to {target_ip}:{target_id} failed")
logger.error(f"connect to {target_ip}:{target_id} failed")
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([
(item['request_id'], "connect error")
])
self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
continue
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item['device_ids'][self.rank])
src_block_ids = paddle.to_tensor(item['src_block_ids'],
dtype='int32',
place='cpu')
dest_block_ids = paddle.to_tensor(item['dest_block_ids'],
dtype='int32',
place='cpu')
if item['current_id'] < prefilled_step_idx:
target_id = int(item["device_ids"][self.rank])
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
if item["current_id"] < prefilled_step_idx:
current_layer_idx = self.num_layers
else:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"],
current_layer_idx):
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[
current_transfer_protocol].write_cache(
target_ip, target_id, src_block_ids,
dest_block_ids, layer_idx)
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0:
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([
(item['request_id'], "write cache error")
])
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
@@ -298,16 +293,14 @@ class CacheMessager(object):
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item['layer_idx'] = current_layer_idx
if item['layer_idx'] == self.num_layers:
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([
(item['request_id'], "finished")
])
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id]
@@ -315,5 +308,4 @@ class CacheMessager(object):
self.last_layer_idx = prefilled_layer_idx
except Exception as e:
logger.error(
f"prefill layerwise send cache thread has exception: {e}")
logger.error(f"prefill layerwise send cache thread has exception: {e}")

View File

@@ -14,18 +14,16 @@
# limitations under the License.
"""
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
class CacheMetrics:
"""
Cache Metrics used to record the cache hit time, token num, request num, etc.
"""
def __init__(self):
self.total_match_time = 0.0
self.avg_match_time = 0.0
@@ -47,19 +45,14 @@ class CacheMetrics:
self.cpu_hit_token_ratio = 0.0
self.gpu_hit_token_ratio = 0.0
def _update_history_hit_metrics(self):
"""
update hit ratio
"""
self.hit_req_ratio = self.hit_req_count / self.req_count
self.hit_token_ratio = self.matched_token_num / self.total_token_num
self.cpu_hit_token_ratio = (
self.total_cpu_matched_token_num / self.total_token_num
)
self.gpu_hit_token_ratio = (
self.total_gpu_matched_token_num / self.total_token_num
)
self.cpu_hit_token_ratio = self.total_cpu_matched_token_num / self.total_token_num
self.gpu_hit_token_ratio = self.total_gpu_matched_token_num / self.total_token_num
logger.info(
f"Metrics for all requests: req_count {self.req_count} hit_req_count {self.hit_req_count}"
@@ -83,29 +76,15 @@ class CacheMetrics:
calculate hit metrics for current query
"""
cpu_cache_match_ratio = (
current_query_cpu_match_token_num / current_query_token_num
)
gpu_cache_match_ratio = (
current_query_gpu_match_token_num / current_query_token_num
)
cpu_cache_match_ratio = current_query_cpu_match_token_num / current_query_token_num
gpu_cache_match_ratio = current_query_gpu_match_token_num / current_query_token_num
total_match_ratio = (
cpu_cache_match_ratio + gpu_cache_match_ratio
)
total_match_ratio = cpu_cache_match_ratio + gpu_cache_match_ratio
self.total_cpu_matched_token_num += current_query_cpu_match_token_num
self.total_gpu_matched_token_num += current_query_gpu_match_token_num
self.total_cpu_matched_token_num += (
current_query_cpu_match_token_num
)
self.total_gpu_matched_token_num += (
current_query_gpu_match_token_num
)
self.matched_token_num += (
current_query_cpu_match_token_num
+ current_query_gpu_match_token_num
)
self.matched_token_num += current_query_cpu_match_token_num + current_query_gpu_match_token_num
self.total_token_num += current_query_token_num
logger.info(
f"Metrics for req_id {req_id}: token_num {current_query_token_num}"

View File

@@ -26,8 +26,11 @@ import paddle
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.engine.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (cuda_host_alloc, set_data_ipc,
swap_cache_all_layers)
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
set_data_ipc,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger
@@ -36,79 +39,58 @@ def parse_args():
从命令行解析参数
"""
parser = argparse.ArgumentParser("Cache transfer manager")
parser.add_argument("--splitwise_role",
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed")
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers",
type=int,
default=1,
help="model num layers")
parser.add_argument("--head_dim",
type=int,
default=1,
help="model head dim")
parser.add_argument("--kv_num_head",
type=int,
default=1,
help="model kv num head")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num",
type=int,
default=1,
help="number of model parallel")
parser.add_argument("--protocol",
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now")
parser.add_argument("--enable_splitwise",
type=int,
default=0,
help="enable splitwise ")
parser.add_argument("--cache_queue_port",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
"--engine_worker_queue_port",
type=int,
default=9923,
help="cache queue port")
parser.add_argument("--pod_ip",
type=str,
default="0.0.0.0",
help="pod ip")
parser.add_argument("--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port")
parser.add_argument("--engine_pid",
type=str,
default=None,
help="engine pid")
help="engine worker queue port",
)
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument("--num_gpu_blocks",
type=int,
default=1,
help="gpu cache block number")
parser.add_argument("--num_cpu_blocks",
type=int,
default=4,
help="cpu cache block number")
parser.add_argument("--block_size",
type=int,
default=64,
help="cache block size(tokens)")
parser.add_argument("--bytes_per_layer_per_block",
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
parser.add_argument(
"--bytes_per_layer_per_block",
type=int,
default=1024,
help="per layer per block bytes")
parser.add_argument("--cache_dtype",
help="per layer per block bytes",
)
parser.add_argument(
"--cache_dtype",
type=str,
default="bfloat16",
choices=["uint8", "bfloat16"],
help="cache dtype")
parser.add_argument("--speculative_config",
help="cache dtype",
)
parser.add_argument(
"--speculative_config",
type=json.loads,
default="{}",
help="speculative config")
help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
args = parser.parse_args()
@@ -134,14 +116,10 @@ class CacheTransferManager:
self.gpu_cache_v_tensors = []
self.speculative_config = SpeculativeConfig(**args.speculative_config)
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = \
int(args.num_gpu_blocks * \
self.speculative_config.num_gpu_block_expand_ratio)
self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=1)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
self.n_ranks = args.mp_num
@@ -154,17 +132,16 @@ class CacheTransferManager:
is_server=False,
num_client=args.mp_num,
client_id=rank,
local_data_parallel_id=args.local_data_parallel_id)
local_data_parallel_id=args.local_data_parallel_id,
)
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else \
self.num_extra_layer_gpu_blocks
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
i, rank, device)] = paddle.full(
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
@@ -174,11 +151,8 @@ class CacheTransferManager:
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
i, rank, device)])
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)] = paddle.full(
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
@@ -188,47 +162,42 @@ class CacheTransferManager:
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
self.gpu_cache_kvs["key_caches_{}_rank{}_device{}".format(
i, rank, device)],
"key_caches_{}_rank{}.device{}".format(i, rank, device))
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
self.gpu_cache_kvs["value_caches_{}_rank{}_device{}".format(
i, rank, device)],
"value_caches_{}_rank{}.device{}".format(i, rank, device))
cache_kv_size_byte = sum(
[tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(
f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
self.cpu_cache_kvs["key_caches_{}_rank{}".format(
i, rank)] = cuda_host_alloc(args.num_cpu_blocks *
args.bytes_per_layer_per_block)
self.k_dst_ptrs.append(
self.cpu_cache_kvs["key_caches_{}_rank{}".format(i, rank)])
self.cpu_cache_kvs["value_caches_{}_rank{}".format(
i, rank)] = cuda_host_alloc(args.num_cpu_blocks *
args.bytes_per_layer_per_block)
self.v_dst_ptrs.append(
self.cpu_cache_kvs["value_caches_{}_rank{}".format(i, rank)])
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(name="cache_ready_signal",
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False)
create=False,
)
self.cache_ready_signal.value[self.rank] = 1
paddle.set_device(f"gpu:{device}")
@@ -251,9 +220,7 @@ class CacheTransferManager:
rdma_port=args.rdma_port,
)
logger.info("successfully create cache messager")
logger.info(
f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}"
)
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
@@ -261,10 +228,17 @@ class CacheTransferManager:
array=cache_task_broadcast_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False)
create=False,
)
def _do_swap_to_cpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id,
event_type, transfer_task_id):
def _do_swap_to_cpu_task(
self,
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
):
"""
swap cache GPU->CPU
"""
@@ -282,14 +256,17 @@ class CacheTransferManager:
if self.rank == 0:
self.cache_task_queue.swap_to_cpu_barrier2.reset()
self.cache_task_queue.put_transfer_done_signal(result)
logger.debug(
f"_do_swap_to_cpu_task: put_transfer_done_signal {result}")
logger.info(
f"_do_swap_to_cpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
)
logger.debug(f"_do_swap_to_cpu_task: put_transfer_done_signal {result}")
logger.info(f"_do_swap_to_cpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")
def _do_swap_to_gpu_task(self, swap_node_ids, gpu_block_id, cpu_block_id,
event_type, transfer_task_id):
def _do_swap_to_gpu_task(
self,
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
):
"""
swap cache CPU->GPU
"""
@@ -307,11 +284,8 @@ class CacheTransferManager:
if self.rank == 0:
self.cache_task_queue.swap_to_gpu_barrier2.reset()
self.cache_task_queue.put_transfer_done_signal(result)
logger.debug(
f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
logger.info(
f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}"
)
logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")
def do_data_transfer(self):
"""
@@ -327,8 +301,7 @@ class CacheTransferManager:
if self.rank == 0:
self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1:
data, read_finish = self.cache_task_queue.get_transfer_task(
)
data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"transfer data: get_transfer_task {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
@@ -386,8 +359,7 @@ class CacheTransferManager:
"""
logger.debug(
f"transfer data: transfer_task_id {transfer_task_id}: swap_node_ids {swap_node_ids}"
+
f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
+ f"task_gpu_block_id {task_gpu_block_id} task_cpu_block_id {task_cpu_block_id} event_type {event_type}"
)
start_time = time.time()
try:
@@ -446,8 +418,7 @@ class CacheTransferManager:
elasped_time = end_time - start_time
logger.info(
f"transfer data: transfer_task_id {transfer_task_id} event_type {event_type}: "
+
f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
+ f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
)
return (
swap_node_ids,
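For reference, the `cache_ready_signal` handshake shown above (each rank writes a 1 into its slot of a shared int32 array, and the launcher in `PrefixCacheManager` later polls `np.sum(...)` until every rank has reported in) reduces to the minimal sketch below. It uses plain numpy plus threads in place of `IPCSignal` and the real cache transfer processes, so everything outside the polling pattern is illustrative only.

```python
# Minimal sketch of the cache-readiness handshake: each rank flips its slot in a
# shared int32 array; the launcher waits until the sum equals the rank count.
# A numpy array + threads stand in for IPCSignal + the cache transfer processes.
import threading
import time

import numpy as np

TP_SIZE = 4  # stands in for tensor_parallel_size / args.mp_num
cache_ready_signal = np.zeros(shape=[TP_SIZE], dtype=np.int32)


def init_cache(rank: int) -> None:
    time.sleep(0.05 * (rank + 1))   # pretend to allocate this rank's KV cache blocks
    cache_ready_signal[rank] = 1    # mirrors self.cache_ready_signal.value[self.rank] = 1


workers = [threading.Thread(target=init_cache, args=(r,)) for r in range(TP_SIZE)]
for w in workers:
    w.start()

# launcher side: "Waiting for cache transfer manager ready..."
while np.sum(cache_ready_signal) != TP_SIZE:
    time.sleep(0.01)
print("all ranks ready:", cache_ready_signal)
```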

View File

@@ -41,11 +41,13 @@ class PrefixCacheManager:
PrefixCacheManager is used to manage the prefix tree and the cache.
"""
def __init__(self,
def __init__(
self,
config,
tensor_parallel_size,
splitwise_role="mixed",
local_data_parallel_id=0):
local_data_parallel_id=0,
):
"""
initialize the PrefixCacheManager
"""
@@ -66,14 +68,12 @@ class PrefixCacheManager:
self.num_cpu_blocks = self.cache_config.num_cpu_blocks
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))
if self.num_cpu_blocks > 0:
self.cpu_free_block_list = list(
range(self.num_cpu_blocks - 1, -1, -1))
self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1))
else:
self.cpu_free_block_list = []
heapq.heapify(self.gpu_free_block_list)
heapq.heapify(self.cpu_free_block_list)
self.node_id_pool = list(
range(self.num_gpu_blocks + self.num_cpu_blocks))
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
@@ -90,7 +90,7 @@ class PrefixCacheManager:
self.task_swapping_event = {}
self.node_map = {}
self.req_leaf_map = ({}) # {request_id: leaf node}
self.req_leaf_map = {} # {request_id: leaf node}
self.leaf_req_map = defaultdict(set)
self.unfilled_req_block_map = defaultdict(list)
@@ -102,14 +102,18 @@ class PrefixCacheManager:
logger.info(
f"num_gpu_blocks_server_owned {self.num_gpu_blocks} num_cpu_blocks "
+
f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
)
def launch_cache_manager(self, cache_config, tensor_parallel_size, \
device_ids, pod_ip, engine_worker_queue_port, pid_suffix):
def launch_cache_manager(
self,
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
):
"""
launch_cache_manager function used to initialize the cache manager.
"""
@@ -120,70 +124,72 @@ class PrefixCacheManager:
array=broadcast_cache_task_flag_array,
dtype=np.int32,
suffix=pid_suffix,
create=True)
create=True,
)
self.cache_task_queue = EngineCacheQueue(
address=(pod_ip, cache_config.cache_queue_port),
authkey=b'cache_queue_service',
authkey=b"cache_queue_service",
is_server=False,
num_client=tensor_parallel_size,
client_id=0,
local_data_parallel_id=self.local_data_parallel_id)
local_data_parallel_id=self.local_data_parallel_id,
)
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename)
if (hasattr(cache_config.model_cfg, "num_key_value_heads")
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads
) // tensor_parallel_size
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size],
dtype=np.int32)
self.cache_ready_signal = IPCSignal(name="cache_ready_signal",
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True)
create=True,
)
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" +
f" {sys.executable} {py_path}" +
f" --device_id {int(device_ids[i])}" + f" --rank {i}" +
f" --splitwise_role {self.splitwise_role}" +
f" --num_layers {cache_config.model_cfg.num_layers}" +
f" --head_dim {cache_config.model_cfg.head_dim}" +
f" --kv_num_head {kv_num_head}" +
f" --mp_num {tensor_parallel_size}" +
f" --cache_dtype {cache_config.cache_dtype}" +
f" --cache_queue_port {cache_config.cache_queue_port}" +
f" --enable_splitwise {int(self.enable_splitwise)}" +
f" --pod_ip {pod_ip}" +
f" --engine_worker_queue_port {engine_worker_queue_port}" +
f" --num_gpu_blocks {cache_config.total_block_num}" +
f" --num_cpu_blocks {cache_config.num_cpu_blocks}" +
f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+ f" --block_size {cache_config.block_size}" +
f" --engine_pid {pid_suffix}" +
f" --protocol {cache_config.cache_transfer_protocol}" +
f" --local_data_parallel_id {self.local_data_parallel_id}" +
f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+
f" --speculative_config '{self.speculative_config.to_json_string()}'"
+
f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+ f" --block_size {cache_config.block_size}"
+ f" --engine_pid {pid_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(
subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# wait for cache initialization to finish
logger.info("Waiting for cache transfer manager ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
@@ -192,9 +198,7 @@ class PrefixCacheManager:
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_manager.log for more information"
)
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
@@ -207,12 +211,10 @@ class PrefixCacheManager:
"""
self.cache_config = cache_config
self.num_gpu_blocks = cache_config.prefill_kvcache_block_num
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1,
-1)) # free GPU block ids managed by the server
self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1)) # free GPU block ids managed by the server
heapq.heapify(self.gpu_free_block_list)
self.node_id_pool = list(
range(self.num_gpu_blocks + self.num_cpu_blocks))
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
def _enable_cpu_cache(self):
"""
@@ -226,8 +228,7 @@ class PrefixCacheManager:
# port=ipc_cache_queue_port,
# )
# start the listener thread that fetches transfer task results
self.transfer_recv_thread = threading.Thread(
target=self.recv_data_transfer_result)
self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
self.transfer_recv_thread.start()
def allocate_gpu_blocks(self, num_blocks):
@@ -237,9 +238,7 @@ class PrefixCacheManager:
assert num_blocks <= len(
self.gpu_free_block_list
), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}"
allocated_block_ids = [
heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)
]
allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)]
logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
@@ -265,9 +264,7 @@ class PrefixCacheManager:
assert num_blocks <= len(
self.cpu_free_block_list
), f"cpu free block num: {len(self.cpu_free_block_list)} < needed number {num_blocks}"
allocated_block_ids = [
heapq.heappop(self.cpu_free_block_list) for i in range(num_blocks)
]
allocated_block_ids = [heapq.heappop(self.cpu_free_block_list) for i in range(num_blocks)]
logger.info(
f"allocate_cpu_blocks: {allocated_block_ids}, len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
)
@@ -307,16 +304,17 @@ class PrefixCacheManager:
"""
self.task_swapping_event[transfer_task_id] = Event()
self.cache_task_queue.put_transfer_task((
self.cache_task_queue.put_transfer_task(
(
swap_node_ids,
gpu_block_ids,
cpu_block_ids,
event_type,
transfer_task_id,
))
)
)
if is_sync:
self.sync_swap_task(transfer_task_id)
return
def sync_swap_task(self, transfer_task_id):
"""
@@ -325,26 +323,27 @@ class PrefixCacheManager:
self.task_swapping_event[transfer_task_id].wait()
del self.task_swapping_event[transfer_task_id]
def _check_validity(self, req_id, match_gpu_blocks_num,
expected_block_num):
def _check_validity(self, req_id, match_gpu_blocks_num, expected_block_num):
"""
check enough gpu memory to allocate cache
"""
if expected_block_num - match_gpu_blocks_num > len(
self.gpu_free_block_list):
if expected_block_num - match_gpu_blocks_num > len(self.gpu_free_block_list):
msg = (
f"request_block_ids: request block for req_id {req_id} failed. "
+
f"matched gpu block num: {match_gpu_blocks_num} require extra gpu block num: "
+
f"{expected_block_num - match_gpu_blocks_num} > free block num: {len(self.gpu_free_block_list)}"
+ f"matched gpu block num: {match_gpu_blocks_num} require extra gpu block num: "
+ f"{expected_block_num - match_gpu_blocks_num} > free block num: {len(self.gpu_free_block_list)}"
)
logger.info(msg)
raise Exception("Not enough GPU memory to allocate cache")
def _prepare_cpu_cache(self, req_id, swap_node_ids, gpu_recv_block_ids, \
cpu_recv_block_ids, match_cpu_block_ids):
def _prepare_cpu_cache(
self,
req_id,
swap_node_ids,
gpu_recv_block_ids,
cpu_recv_block_ids,
match_cpu_block_ids,
):
"""
transfer the matched CPU cache blocks to the GPU
"""
@@ -357,11 +356,8 @@ class PrefixCacheManager:
for tmp_cpu_block_id in match_cpu_block_ids:
need_transfer_task_cpu_block_ids.append(tmp_cpu_block_id)
assert len(need_transfer_task_gpu_block_ids) == len(
need_transfer_task_cpu_block_ids)
logger.info(
f"request_block_ids: req_id {req_id} issue_swap_task transfer_task_id {transfer_task_id}"
)
assert len(need_transfer_task_gpu_block_ids) == len(need_transfer_task_cpu_block_ids)
logger.info(f"request_block_ids: req_id {req_id} issue_swap_task transfer_task_id {transfer_task_id}")
self.issue_swap_task(
transfer_task_id,
swap_node_ids,
@@ -371,8 +367,16 @@ class PrefixCacheManager:
True,
)
def _prepare_cache(self, req_id, input_ids, block_size, \
expected_block_num, match_gpu_block_ids, match_cpu_block_ids, match_node_ids):
def _prepare_cache(
self,
req_id,
input_ids,
block_size,
expected_block_num,
match_gpu_block_ids,
match_cpu_block_ids,
match_node_ids,
):
"""
prepare cache for request
"""
@@ -394,8 +398,13 @@ class PrefixCacheManager:
gpu_extra_block_ids = self.allocate_gpu_blocks(gpu_extra_block_num)
if len(gpu_recv_block_ids) > 0:
self._prepare_cpu_cache(req_id, match_node_ids, gpu_recv_block_ids, \
cpu_recv_block_ids, match_cpu_block_ids)
self._prepare_cpu_cache(
req_id,
match_node_ids,
gpu_recv_block_ids,
cpu_recv_block_ids,
match_cpu_block_ids,
)
return gpu_recv_block_ids, gpu_extra_block_ids
@@ -423,9 +432,7 @@ class PrefixCacheManager:
self.metrics.req_count += 1
input_ids = task.prompt_token_ids
req_id = task.request_id
logger.info(
f"request_block_ids: start to allocate blocks for req_id {req_id}"
)
logger.info(f"request_block_ids: start to allocate blocks for req_id {req_id}")
input_token_num = len(input_ids)
common_block_ids = []
unique_block_ids = []
@@ -443,34 +450,43 @@ class PrefixCacheManager:
matched_block_num = match_gpu_blocks_num + match_cpu_blocks_num
matched_token_num_in_cpu_and_gpu = gpu_match_token_num + cpu_match_token_num
# check enough gpu memory to allocate cache
block_num = (input_token_num + block_size - 1 +
dec_token_num) // block_size
block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
self._check_validity(req_id, matched_block_num, block_num)
# update matched node info
current_time = time.time()
self._update_matched_node_info(req_id, match_block_node,
current_time)
self._update_matched_node_info(req_id, match_block_node, current_time)
# 2. prepare cache
gpu_recv_block_ids, gpu_extra_block_ids, = self._prepare_cache(req_id, \
input_ids, block_size, block_num, match_gpu_block_ids, match_cpu_block_ids, swap_node_ids)
(gpu_recv_block_ids, gpu_extra_block_ids,) = self._prepare_cache(
req_id,
input_ids,
block_size,
block_num,
match_gpu_block_ids,
match_cpu_block_ids,
swap_node_ids,
)
# update matched token num
matched_block_num = (gpu_match_token_num + cpu_match_token_num)
matched_block_num = gpu_match_token_num + cpu_match_token_num
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
unique_block_ids = gpu_extra_block_ids
dec_block_num = dec_token_num // block_size
left_input_ids = input_ids[
matched_token_num_in_cpu_and_gpu:] # tokens not yet in the prefix tree
left_input_ids = input_ids[matched_token_num_in_cpu_and_gpu:] # tokens not yet in the prefix tree
gpu_build_path_block_ids = []
gpu_build_path_block_ids = gpu_extra_block_ids
leaf_node = self.build_path(req_id, current_time, input_ids,
leaf_node = self.build_path(
req_id,
current_time,
input_ids,
left_input_ids,
gpu_build_path_block_ids,
block_size, match_block_node,
dec_block_num)
block_size,
match_block_node,
dec_block_num,
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
# 3. update metrics
@@ -482,17 +498,15 @@ class PrefixCacheManager:
gpu_match_token_num,
input_token_num,
)
hit_info[
"gpu_cache_blocks"] = gpu_match_token_num // block_size
hit_info[
"cpu_cache_blocks"] = cpu_match_token_num // block_size
hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
hit_info["cpu_cache_blocks"] = cpu_match_token_num // block_size
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
logger.info(
f"request_block_ids: request block for req_id {req_id}: common_block_ids "
+
f"{common_block_ids}, unique_block_ids {unique_block_ids}")
+ f"{common_block_ids}, unique_block_ids {unique_block_ids}"
)
return common_block_ids, unique_block_ids, hit_info
except Exception as e:
logger.error(f"request_block_ids: error: {type(e)} {e}")
@@ -523,25 +537,21 @@ class PrefixCacheManager:
node.decrement_shared_count()
node = node.parent
logger.info(
f"release_block_ids: req_id {req_id} leaf_node {leaf_node}"
)
logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
if leaf_node == self.radix_tree_root:
self.recycle_gpu_blocks(
self.unfilled_req_block_map[req_id])
self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id])
del self.unfilled_req_block_map[req_id]
return
if leaf_node in self.gpu_lru_leaf_set:
return
if (leaf_node.shared_count == 0 and leaf_node.is_gpu_leaf_node
and leaf_node.is_persistent is False):
if leaf_node.shared_count == 0 and leaf_node.is_gpu_leaf_node and leaf_node.is_persistent is False:
self.gpu_lru_leaf_set.add(leaf_node)
heapq.heappush(self.gpu_lru_leaf_heap, leaf_node)
logger.info(
f"release_block_ids: req_id {req_id} has been finished, " +
f"current gpu_lru_leaf_heap length {len(self.gpu_lru_leaf_heap)}"
f"release_block_ids: req_id {req_id} has been finished, "
+ f"current gpu_lru_leaf_heap length {len(self.gpu_lru_leaf_heap)}"
)
return
except Exception as e:
@@ -563,8 +573,15 @@ class PrefixCacheManager:
node.reverved_dec_block_ids = []
self.recycle_gpu_blocks(node.block_id)
def _handle_free_gpu_node_with_cpu(self, node, hash_value_input_ids_map, \
hash_value_depth_map, need_recycle_gpu_block_ids, hash_value_gpu_block_ids_map, hash_value_swap_node_ids_map):
def _handle_free_gpu_node_with_cpu(
self,
node,
hash_value_input_ids_map,
hash_value_depth_map,
need_recycle_gpu_block_ids,
hash_value_gpu_block_ids_map,
hash_value_swap_node_ids_map,
):
"""
GPU node eviction in hierarchical cache layers
"""
@@ -573,14 +590,19 @@ class PrefixCacheManager:
node.reverved_dec_block_ids = []
need_recycle_gpu_block_ids.append(node.block_id)
hash_value_gpu_block_ids_map[node.input_hash_value].append(
node.block_id)
hash_value_swap_node_ids_map[node.input_hash_value].append(
node.node_id)
hash_value_gpu_block_ids_map[node.input_hash_value].append(node.block_id)
hash_value_swap_node_ids_map[node.input_hash_value].append(node.node_id)
def _evict_cache_async(self, future, total_gpu_free_count, \
hash_value_gpu_block_ids_map, hash_value_block_ids_map, \
hash_value_swap_node_ids_map, hash_value_input_ids_map, hash_value_depth_map):
def _evict_cache_async(
self,
future,
total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
hash_value_swap_node_ids_map,
hash_value_input_ids_map,
hash_value_depth_map,
):
"""
evict cache async (GPU --> CPU)
"""
@@ -592,23 +614,21 @@ class PrefixCacheManager:
need_transfer_task_cpu_block_ids = []
cpu_block_ids = self.allocate_cpu_blocks(total_gpu_free_count)
for input_hash_value in hash_value_gpu_block_ids_map.keys():
need_transfer_task_gpu_block_ids.extend(
reversed(hash_value_gpu_block_ids_map[input_hash_value]))
need_transfer_task_gpu_block_ids.extend(reversed(hash_value_gpu_block_ids_map[input_hash_value]))
all_allocated_cpu_block_ids = []
for _ in reversed(hash_value_gpu_block_ids_map[input_hash_value]):
cpu_block_id_t = cpu_block_ids.pop(0)
all_allocated_cpu_block_ids.append(cpu_block_id_t)
need_transfer_task_cpu_block_ids.append(cpu_block_id_t)
swap_node_ids.extend(
reversed(hash_value_swap_node_ids_map[input_hash_value]))
swap_node_ids.extend(reversed(hash_value_swap_node_ids_map[input_hash_value]))
logger.info(
"free_block_ids_async: issue transfer task: " +
f"transfer_task_id {transfer_task_id}: " +
f"swap_node_ids {swap_node_ids} need_transfer_task_gpu_block_ids "
+
f"{need_transfer_task_gpu_block_ids}, need_transfer_task_cpu_block_ids "
+ f"{need_transfer_task_cpu_block_ids}, CacheStatus.SWAP2CPU")
"free_block_ids_async: issue transfer task: "
+ f"transfer_task_id {transfer_task_id}: "
+ f"swap_node_ids {swap_node_ids} need_transfer_task_gpu_block_ids "
+ f"{need_transfer_task_gpu_block_ids}, need_transfer_task_cpu_block_ids "
+ f"{need_transfer_task_cpu_block_ids}, CacheStatus.SWAP2CPU"
)
self.issue_swap_task(
transfer_task_id,
swap_node_ids,
@@ -619,9 +639,8 @@ class PrefixCacheManager:
)
logger.info(
"free_block_ids_async: after free, " +
f"len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}")
return
"free_block_ids_async: after free, " + f"len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
def free_block_ids_async(self, need_block_num):
"""
@@ -654,8 +673,10 @@ class PrefixCacheManager:
break
node = heapq.heappop(self.gpu_lru_leaf_heap)
self.gpu_lru_leaf_set.remove(node)
if not self.cache_config.enable_hierarchical_cache or \
self.cache_config.num_cpu_blocks < need_block_num:
if (
not self.cache_config.enable_hierarchical_cache
or self.cache_config.num_cpu_blocks < need_block_num
):
if node.shared_count == 0 and node.is_gpu_leaf_node: # recycle directly
self._handle_free_gpu_node_without_cpu(node)
total_gpu_free_count += 1
@@ -666,12 +687,13 @@ class PrefixCacheManager:
if not node.children:
if node in self.gpu_lru_leaf_set:
continue
if (node != self.radix_tree_root
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_gpu_leaf_node
and node.is_persistent is False):
heapq.heappush(self.gpu_lru_leaf_heap,
node)
and node.is_persistent is False
):
heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node)
else:
continue
@@ -680,18 +702,25 @@ class PrefixCacheManager:
node.cache_status = CacheStatus.SWAP2CPU
else:
continue
self._handle_free_gpu_node_with_cpu(node, hash_value_input_ids_map, \
hash_value_depth_map, need_recycle_gpu_block_ids, \
hash_value_gpu_block_ids_map, hash_value_swap_node_ids_map)
self._handle_free_gpu_node_with_cpu(
node,
hash_value_input_ids_map,
hash_value_depth_map,
need_recycle_gpu_block_ids,
hash_value_gpu_block_ids_map,
hash_value_swap_node_ids_map,
)
total_gpu_free_count += 1
node = node.parent
if node in self.gpu_lru_leaf_set:
continue
if (node != self.radix_tree_root
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_gpu_leaf_node
and node.is_persistent is False):
and node.is_persistent is False
):
heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node)
@@ -702,12 +731,16 @@ class PrefixCacheManager:
cpu_free_count = total_gpu_free_count
if cpu_free_count < need_block_num:
cpu_free_count = need_block_num
cpu_free_future = self.free_cpu_executor_pool.submit(
self.free_cpu_block_ids, cpu_free_count)
cpu_free_future = self.free_cpu_executor_pool.submit(self.free_cpu_block_ids, cpu_free_count)
self.gpu_free_task_future = self.free_gpu_executor_pool.submit(
self._evict_cache_async, cpu_free_future, total_gpu_free_count, \
hash_value_gpu_block_ids_map, hash_value_block_ids_map, \
hash_value_swap_node_ids_map, hash_value_input_ids_map, hash_value_depth_map
self._evict_cache_async,
cpu_free_future,
total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
hash_value_swap_node_ids_map,
hash_value_input_ids_map,
hash_value_depth_map,
)
else:
self.gpu_free_task_future = None
@@ -724,10 +757,7 @@ class PrefixCacheManager:
Returns:
- freed_block_num: Number of CPU blocks successfully evicted
"""
hash_value_input_ids_map = {}
hash_value_block_ids_map = defaultdict(list)
hash_value_depth_map = {}
need_recycle_cpu_block_ids = []
total_cpu_free_count = 0
with self.request_release_lock:
while True:
@@ -739,13 +769,10 @@ class PrefixCacheManager:
node = heapq.heappop(self.cpu_lru_leaf_heap)
self.cpu_lru_leaf_set.remove(node)
tmp_block_ids = []
if (node.shared_count == 0
and node.cache_status == CacheStatus.CPU
and node.is_cpu_leaf_node):
if node.shared_count == 0 and node.cache_status == CacheStatus.CPU and node.is_cpu_leaf_node:
self.recycle_cpu_blocks(node.block_id)
hash_value_block_ids_map[node.input_hash_value].extend(
reversed(tmp_block_ids))
hash_value_block_ids_map[node.input_hash_value].extend(reversed(tmp_block_ids))
logger.info(f"free_cpu_block_ids: free node {node}")
self.node_id_pool.append(node.node_id)
@@ -759,15 +786,17 @@ class PrefixCacheManager:
if not node.children:
if node in self.cpu_lru_leaf_set:
continue
if (node != self.radix_tree_root
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_cpu_leaf_node
and node.cache_status == CacheStatus.CPU):
and node.cache_status == CacheStatus.CPU
):
heapq.heappush(self.cpu_lru_leaf_heap, node)
self.cpu_lru_leaf_set.add(node)
logger.info(
"free_cpu_block_ids: after free, " +
f"len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}")
"free_cpu_block_ids: after free, " + f"len(self.cpu_free_block_list) {len(self.cpu_free_block_list)}"
)
return total_cpu_free_count
def cal_block_hash(self, block):
@@ -807,8 +836,7 @@ class PrefixCacheManager:
with self.cache_status_lock:
while match_token_num < total_token_num:
token_block = input_ids[match_token_num:match_token_num +
block_size]
token_block = input_ids[match_token_num : match_token_num + block_size]
token_num = len(token_block)
if token_num != block_size:
break
@@ -817,11 +845,11 @@ class PrefixCacheManager:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
match_node_ids.append(child.node_id)
if (child in self.gpu_lru_leaf_set):
if child in self.gpu_lru_leaf_set:
self.gpu_lru_leaf_set.remove(child)
self.gpu_lru_leaf_heap.remove(child)
has_modified_gpu_lru_leaf_heap = True
elif (child in self.cpu_lru_leaf_set):
elif child in self.cpu_lru_leaf_set:
self.cpu_lru_leaf_set.remove(child)
self.cpu_lru_leaf_heap.remove(child)
has_modified_cpu_lru_leaf_heap = True
@@ -831,8 +859,9 @@ class PrefixCacheManager:
else:
if child.cache_status == CacheStatus.SWAP2CPU:
logger.info(
f"match_block: req_id {req_id} matched node" +
f" {child.node_id} which is being SWAP2CPU")
f"match_block: req_id {req_id} matched node"
+ f" {child.node_id} which is being SWAP2CPU"
)
child.cache_status = CacheStatus.GPU
match_gpu_block_ids.append(child.block_id)
gpu_match_token_num += block_size
@@ -851,8 +880,7 @@ class PrefixCacheManager:
if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap)
logger.info(
f"match_block: req_id {req_id} matched nodes: {match_node_ids}")
logger.info(f"match_block: req_id {req_id} matched nodes: {match_node_ids}")
return (
match_gpu_block_ids,
match_cpu_block_ids,
@@ -873,9 +901,17 @@ class PrefixCacheManager:
node.req_id_set.add(req_id)
node = node.parent
def build_path(self, req_id, current_time, input_ids, left_input_ids,
gpu_block_ids, block_size, last_node,
reverved_dec_block_num):
def build_path(
self,
req_id,
current_time,
input_ids,
left_input_ids,
gpu_block_ids,
block_size,
last_node,
reverved_dec_block_num,
):
"""
Build path for blocks beyond the common prefix
Parameters:
@@ -915,7 +951,8 @@ class PrefixCacheManager:
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
new_last_node = BlockNode(node_id,
new_last_node = BlockNode(
node_id,
input_ids,
input_hash_value,
node.depth + 1,
@@ -925,7 +962,8 @@ class PrefixCacheManager:
current_time,
parent=node,
shared_count=1,
reverved_dec_block_ids=[])
reverved_dec_block_ids=[],
)
new_last_node.req_id_set.add(req_id)
self.node_map[node_id] = new_last_node
node.children[hash_value] = new_last_node
@@ -939,46 +977,44 @@ class PrefixCacheManager:
self.unfilled_req_block_map[req_id] = reverved_dec_block_ids
else:
new_last_node.reverved_dec_block_ids.extend(reverved_dec_block_ids)
logger.info(
f"build_path: allocate unique node ids {unique_node_ids} for req_id {req_id}"
)
logger.info(f"build_path: allocate unique node ids {unique_node_ids} for req_id {req_id}")
return new_last_node
def _handle_swap_result(self, swap_node_id, task_gpu_block_id,
task_cpu_block_id, event_type):
def _handle_swap_result(self, swap_node_id, task_gpu_block_id, task_cpu_block_id, event_type):
"""
handle swap result
"""
if swap_node_id is None:
return
with self.cache_status_lock:
if (event_type.value == CacheStatus.SWAP2CPU.value):
if event_type.value == CacheStatus.SWAP2CPU.value:
gpu_block_id = task_gpu_block_id
cpu_block_id = task_cpu_block_id
node = self.node_map[swap_node_id]
if node.cache_status.value == CacheStatus.GPU.value:
logger.info(
f"recv_data_transfer_result: node {node.node_id} " +
f"has been reused when SWAP2CPU, recycle cpu block id {cpu_block_id}"
f"recv_data_transfer_result: node {node.node_id} "
+ f"has been reused when SWAP2CPU, recycle cpu block id {cpu_block_id}"
)
self.recycle_cpu_blocks(cpu_block_id)
else:
node.cache_status = CacheStatus.CPU
node.block_id = cpu_block_id
if (node != self.radix_tree_root and node.shared_count == 0
if (
node != self.radix_tree_root
and node.shared_count == 0
and node.is_cpu_leaf_node
and node.cache_status == CacheStatus.CPU):
and node.cache_status == CacheStatus.CPU
):
if node not in self.cpu_lru_leaf_set:
heapq.heappush(self.cpu_lru_leaf_heap, node)
self.cpu_lru_leaf_set.add(node)
self.recycle_gpu_blocks(gpu_block_id)
logger.info(
f"recv_data_transfer_result: after SWAP2CPU, node {node}"
)
logger.info(f"recv_data_transfer_result: after SWAP2CPU, node {node}")
elif (event_type.value == CacheStatus.SWAP2GPU.value):
elif event_type.value == CacheStatus.SWAP2GPU.value:
gpu_block_id = task_gpu_block_id
cpu_block_id = task_cpu_block_id
@@ -987,12 +1023,12 @@ class PrefixCacheManager:
node.block_id = gpu_block_id
self.recycle_cpu_blocks(cpu_block_id)
logger.info(
f"recv_data_transfer_result: after SWAP2GPU, node {node}")
logger.info(f"recv_data_transfer_result: after SWAP2GPU, node {node}")
else:
logger.warning(
f"recv_data_transfer_result: Get unexpected event type {event_type}"
+ ", only SWAP2CPU and SWAP2GPU supported")
+ ", only SWAP2CPU and SWAP2GPU supported"
)
def recv_data_transfer_result(self):
"""
@@ -1024,10 +1060,8 @@ class PrefixCacheManager:
self.task_swapping_event[transfer_task_id].set()
logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+
f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+
f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}")
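The free-list bookkeeping reformatted above comes down to a min-heap of free block ids plus one ceiling division to size a request. The sketch below reproduces just that arithmetic with `heapq` outside the class so the numbers are easy to verify; the block counts and token counts are invented for the example.

```python
# Sketch of allocate_gpu_blocks / recycle_gpu_blocks and the block_num formula
# from request_block_ids. All sizes here are illustrative.
import heapq

num_gpu_blocks = 8
block_size = 64
dec_token_num = 128  # decoder slots reserved ahead of time

# the free list is kept as a min-heap of block ids, as in PrefixCacheManager
gpu_free_block_list = list(range(num_gpu_blocks - 1, -1, -1))
heapq.heapify(gpu_free_block_list)


def allocate_gpu_blocks(num_blocks):
    assert num_blocks <= len(gpu_free_block_list), "not enough free GPU blocks"
    return [heapq.heappop(gpu_free_block_list) for _ in range(num_blocks)]


def recycle_gpu_blocks(block_ids):
    for block_id in block_ids:
        heapq.heappush(gpu_free_block_list, block_id)


input_token_num = 100
block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size
print(block_num)          # (100 + 63 + 128) // 64 == 4

blocks = allocate_gpu_blocks(block_num)
print(blocks)             # lowest ids come out first: [0, 1, 2, 3]
recycle_gpu_blocks(blocks)
```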

View File

@@ -13,5 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .ipc_cache_transfer import IPCCommManager
from .rdma_cache_transfer import RDMACommManager
__all__ = ["IPCCommManager", "RDMACommManager"]

View File

@@ -14,13 +14,13 @@
# limitations under the License.
"""
import os
import paddle
from fastdeploy.model_executor.ops.gpu import (
get_data_ptr_ipc, ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync)
get_data_ptr_ipc,
ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync,
)
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
@@ -44,17 +44,13 @@ class IPCConnector:
self.rank_id = rank_id_
self.local_gpu_id = int(local_gpu_id_)
tmp = paddle.ones([1, 1])
logger.info(
f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}"
)
logger.info(f"init ipc rank{self.rank_id} with remote {self.remote_gpu_id} {self.local_gpu_id}")
for layer_id in range(layer_num):
key_unique_name = f"key_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
value_unique_name = f"value_caches_{layer_id}_rank{self.rank_id}.device{self.remote_gpu_id}"
self.remote_key_tensor_ptr_list.append(
get_data_ptr_ipc(tmp, key_unique_name))
self.remote_value_tensor_ptr_list.append(
get_data_ptr_ipc(tmp, value_unique_name))
self.write_stream = paddle.device.Stream(f'gpu:{self.local_gpu_id}')
self.remote_key_tensor_ptr_list.append(get_data_ptr_ipc(tmp, key_unique_name))
self.remote_value_tensor_ptr_list.append(get_data_ptr_ipc(tmp, value_unique_name))
self.write_stream = paddle.device.Stream(f"gpu:{self.local_gpu_id}")
self.finish_event = paddle.device.Event()
@@ -83,14 +79,11 @@ class IPCCommManager:
"""
Connect to remote gpu.
"""
logger.info(
f"{self.rank_id}: connect to remote_gpu_id:{remote_gpu_id_} {self.layer_num} {self.gpu_idx}"
)
logger.info(f"{self.rank_id}: connect to remote_gpu_id:{remote_gpu_id_} {self.layer_num} {self.gpu_idx}")
if self.is_connected(remote_gpu_id_):
return True
else:
self.comm_map[remote_gpu_id_] = IPCConnector(
self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
self.comm_map[remote_gpu_id_] = IPCConnector(self.rank_id, remote_gpu_id_, self.layer_num, self.gpu_idx)
return True
def is_connected(self, remote_gpu_id_=0):
@@ -102,8 +95,7 @@ class IPCCommManager:
else:
return False
def write_cache(self, ip, remote_gpu_id, local_block_ids, remote_block_ids,
layer_idx):
def write_cache(self, ip, remote_gpu_id, local_block_ids, remote_block_ids, layer_idx):
"""
Connect to remote gpu and write cache.
"""
@@ -114,20 +106,26 @@ class IPCCommManager:
with paddle.device.stream_guard(comm.write_stream):
ipc_sent_key_value_cache_by_remote_ptr(
self.local_key_cache_tensor_list[layer_idx],
self.local_value_cache_tensor_list[layer_idx], local_block_ids,
remote_block_ids, comm.remote_key_tensor_ptr_list[layer_idx],
comm.remote_value_tensor_ptr_list[layer_idx], block_num,
self.gpu_idx, comm.remote_gpu_id,
comm.write_stream.stream_base.cuda_stream)
self.local_value_cache_tensor_list[layer_idx],
local_block_ids,
remote_block_ids,
comm.remote_key_tensor_ptr_list[layer_idx],
comm.remote_value_tensor_ptr_list[layer_idx],
block_num,
self.gpu_idx,
comm.remote_gpu_id,
comm.write_stream.stream_base.cuda_stream,
)
return 0
def write_block_by_sync(self, remote_gpu_id):
"""
check finish event and wait for it
"""
paddle.set_device(f'gpu:{self.gpu_idx}')
paddle.set_device(f"gpu:{self.gpu_idx}")
comm = self.comm_map[remote_gpu_id]
ipc_sent_key_value_cache_by_remote_ptr_block_sync(
self.local_key_cache_tensor_list[0], # tensor no use
self.local_value_cache_tensor_list[0], # tensor no use
comm.write_stream.stream_base.cuda_stream)
comm.write_stream.stream_base.cuda_stream,
)
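One detail worth making explicit: the IPC path only works because both sides agree on the string under which each layer's cache tensor is published. The cache transfer manager registers tensors via `set_data_ipc` with names like `key_caches_{i}_rank{rank}.device{device}`, and `IPCConnector` rebuilds the same names before calling `get_data_ptr_ipc`. The sketch below checks only that naming contract; the GPU ops themselves are not touched here.

```python
# Naming contract between the producer (set_data_ipc in cache_transfer_manager)
# and the consumer (get_data_ptr_ipc in IPCConnector). Pure string handling.
def published_names(layer_id, rank, device):
    return (
        f"key_caches_{layer_id}_rank{rank}.device{device}",
        f"value_caches_{layer_id}_rank{rank}.device{device}",
    )


def looked_up_names(layer_id, rank_id, remote_gpu_id):
    return (
        f"key_caches_{layer_id}_rank{rank_id}.device{remote_gpu_id}",
        f"value_caches_{layer_id}_rank{rank_id}.device{remote_gpu_id}",
    )


for layer_id in range(3):
    assert published_names(layer_id, rank=0, device=1) == looked_up_names(layer_id, rank_id=0, remote_gpu_id=1)
print("cache tensor names match on both sides")
```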

View File

@@ -42,11 +42,13 @@ Bandwidth Saturation Capability: Under multi-threaded high-pressure scenarios, b
### Dependencies Installation
#### Python Packages
```bash
pip install pyzmq pybind11[global]
```
#### System Libraries (Linux)
```bash
# Ubuntu/Debian
sudo apt-get install -y libibverbs-dev librdmacm-dev
@@ -107,8 +109,6 @@ pip install dist/*.whl
|----------|---------|-------------|
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
# Set RDMA GID index
export KVCACHE_RDMA_GID_INDEX=3
@@ -125,7 +125,6 @@ export KVCACHE_DEBUG=1
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
## Network configurations
kvcache transfer is fully tested with RDMA over Converged Ethernet (RoCE) networks. However, it is theoretically compatible with Infiniband as well.

View File

@@ -43,11 +43,13 @@
### Dependencies Installation
#### Python Packages
```bash
pip install pyzmq pybind11[global]
```
#### System Libraries (Linux)
```bash
# Ubuntu/Debian
sudo apt-get install -y libibverbs-dev librdmacm-dev
@@ -108,7 +110,6 @@ pip install dist/*.whl
|------|--------|------|
| `KVCACHE_GDRCOPY_FLUSH_ENABLE` | false | Enable GDRCopy flush for Ampere GPUs |
# Set RDMA GID index
export KVCACHE_RDMA_GID_INDEX=3
@@ -125,7 +126,6 @@ export KVCACHE_DEBUG=1
export KVCACHE_DEBUG_FILE=/var/log/kvcache_debug.log
export KVCACHE_ERROR_FILE=/var/log/kvcache_error.log
## Network configurations
kvcache transfer is fully tested with RDMA over Converged Ethernet (RoCE) networks. In theory it is also compatible with InfiniBand.

View File

@@ -24,13 +24,24 @@ class RDMACommManager:
RDMACommManager to manage rdma communication
"""
def __init__(self, splitwise_role, rank, gpu_id, cache_k_ptr_list, \
cache_v_ptr_list, max_block_num, block_bytes, rdma_port):
def __init__(
self,
splitwise_role,
rank,
gpu_id,
cache_k_ptr_list,
cache_v_ptr_list,
max_block_num,
block_bytes,
rdma_port,
):
try:
import rdma_comm
except:
logger.error(f"The installation of the RDMA library failed." \
"Confirm whether your network card supports RDMA transmission.")
logger.error(
"The installation of the RDMA library failed."
"Confirm whether your network card supports RDMA transmission."
)
return
self.messager = rdma_comm.RDMACommunicator(
splitwise_role,
@@ -50,7 +61,7 @@ class RDMACommManager:
Connect to remote gpu and write cache.
"""
assert self.splitwise_role == "prefill", "only prefill can call this method"
addr = f"{ip}:{str(port)}"
addr = f"{ip}:{port!s}"
if addr in self.connected_rdma:
return True
ret = self.messager.is_connected(ip, str(port))
@@ -59,18 +70,13 @@ class RDMACommManager:
return True
ret = self.messager.connect(ip, str(port))
logger.info(
f"connect to remote rdma address {ip}:{port} status is {ret}")
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
if ret == 0:
self.connected_rdma.add(addr)
return ret == 0
def write_cache(self, ip, port, local_block_ids, remote_block_ids,
layer_idx):
def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):
"""
Connect to remote gpu and write cache.
"""
return self.messager.write_cache(ip, str(port), local_block_ids,
remote_block_ids, layer_idx)
return self.messager.write_cache(ip, str(port), local_block_ids, remote_block_ids, layer_idx)
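The connect path above memoizes successful handshakes in `self.connected_rdma`, keyed by the `ip:port` string, so repeated writes to the same remote instance skip connection setup. Below is a stripped-down sketch of that caching with the `rdma_comm` communicator replaced by a stub that counts handshakes; only the caching logic mirrors the real class, everything else is invented for the example.

```python
# Connect-once caching as in RDMACommManager, with rdma_comm stubbed out.
class _StubMessager:
    def __init__(self):
        self.handshakes = 0

    def is_connected(self, ip, port):
        return 0  # pretend no link is established yet

    def connect(self, ip, port):
        self.handshakes += 1
        return 0  # 0 means success, as in the real communicator


class ConnectCache:
    def __init__(self):
        self.messager = _StubMessager()
        self.connected_rdma = set()

    def connect(self, ip, port):
        addr = f"{ip}:{port!s}"
        if addr in self.connected_rdma:
            return True
        if self.messager.is_connected(ip, str(port)):
            self.connected_rdma.add(addr)
            return True
        ret = self.messager.connect(ip, str(port))
        if ret == 0:
            self.connected_rdma.add(addr)
        return ret == 0


cache = ConnectCache()
assert cache.connect("10.0.0.2", 8765)
assert cache.connect("10.0.0.2", 8765)   # second call hits the cache
print("handshakes performed:", cache.messager.handshakes)  # 1
```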

View File

@@ -24,12 +24,12 @@ from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantConfigBase
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.utils import get_logger
logger = get_logger("config", "config.log")
class MoEPhase(Enum):
"""
The generation phase of the moe.
@@ -38,13 +38,14 @@ class MoEPhase(Enum):
PREFILL = 1
DECODER = 2
class ErnieArchitectures:
"""Helper class for ERNIE architecture check."""
ARCHITECTURES = {
"Ernie4_5_ForCausalLM",
"Ernie4_5_MoeForCausalLM",
"Ernie4_5_VLMoeForConditionalGeneration"
"Ernie4_5_VLMoeForConditionalGeneration",
}
@classmethod
@@ -57,6 +58,7 @@ class ErnieArchitectures:
"""Check if the given architecture is an ERNIE architecture."""
return architecture in cls.ARCHITECTURES
PRETRAINED_INIT_CONFIGURATION = {
"rope_theta": 10000.0,
"num_key_value_heads": -1,
@@ -81,6 +83,7 @@ class ModelConfig:
"""
The configuration class to store the configuration of a `LLM`.
"""
def __init__(
self,
args,
@@ -134,6 +137,7 @@ class ModelConfig:
class ParallelConfig:
"""Configuration for the distributed execution."""
def __init__(
self,
args,
@@ -213,10 +217,8 @@ class ParallelConfig:
self.enable_custom_all_reduce: bool = False
# pd_disaggregation
use_pd_disaggregation: int = int(
os.getenv("FLAGS_use_pd_disaggregation", 0))
use_pd_disaggregation_per_chunk: int = int(
os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
use_pd_disaggregation_per_chunk: int = int(os.getenv("FLAGS_use_pd_disaggregation_per_chunk", 0))
if use_pd_disaggregation_per_chunk:
self.pd_disaggregation_mode = "per_chunk"
elif use_pd_disaggregation:
@@ -224,10 +226,12 @@ class ParallelConfig:
else:
self.pd_disaggregation_mode = "None"
class SpeculativeConfig:
"""
Configuration for speculative decoding.
"""
def __init__(
self,
args,
@@ -263,20 +267,24 @@ class SpeculativeConfig:
# TODO(YuanRisheng): The name of the server args is different from the name of the SpeculativeConfig.
# We temporarily add the name map here and will delete it in the future.
name_map = {"speculative_method": "method",
name_map = {
"speculative_method": "method",
"speculative_max_draft_token_num": "num_speculative_tokens",
"speculative_model_name_or_path": "model_name_or_path",
"speculative_model_quantization": "quantization",
"speculative_benchmark_mode": "benchmark_mode"}
"speculative_benchmark_mode": "benchmark_mode",
}
for key, value in args.items():
if key in name_map.keys() and hasattr(self, name_map[key]):
setattr(self, name_map[key], value)
class DeviceConfig:
"""
Configuration for device settings.
"""
def __init__(
self,
args,
@@ -286,6 +294,7 @@ class DeviceConfig:
if hasattr(self, key):
setattr(self, key, value)
@dataclass
class GraphOptimizationConfig:
"""
@@ -336,15 +345,10 @@ class GraphOptimizationConfig:
full_cuda_graph: bool = True
max_capture_size: int = field(default=None, init=False) # type: ignore
batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore
batch_size_to_captured_size: dict[int, int] = field(default=None, init=False) # type: ignore
# CINN Config ...
def init_with_cudagrpah_size(
self,
max_num_seqs:int = 0
) -> None:
def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
@@ -353,32 +357,28 @@ class GraphOptimizationConfig:
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
self.cudagraph_capture_sizes, dedup_sizes)
logger.info(
("cudagraph sizes specified by model runner" " %s is overridden by config %s"),
self.cudagraph_capture_sizes,
dedup_sizes,
)
self.cudagraph_capture_sizes = dedup_sizes
# Sort to make sure cudagraph capture sizes are in descending order
self.cudagraph_capture_sizes.sort(reverse=True)
self.max_capture_size = self.cudagraph_capture_sizes[
0] if self.cudagraph_capture_sizes else 0
self.max_capture_size = self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0
# Pre-compute the mapping from batch size to padded graph size
self.batch_size_to_captured_size = {}
for end, start in zip(self.cudagraph_capture_sizes,
self.cudagraph_capture_sizes[1:] + [0]):
for end, start in zip(self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0]):
for bs in range(start, end):
if bs == start:
self.batch_size_to_captured_size[bs] = start
else:
self.batch_size_to_captured_size[bs] = end
self.batch_size_to_captured_size[
self.max_capture_size] = self.max_capture_size
self.batch_size_to_captured_size[self.max_capture_size] = self.max_capture_size
def _set_cudagraph_sizes(
self,
max_num_seqs:int = 0
):
def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
"""
Calculate a series of candidate capture batch sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input.
@@ -405,24 +405,28 @@ class LoadConfig:
- 'ipc_snapshot': Load from disk snapshot of IPC weights
- None: No dynamic loading
"""
def __init__(
self,
args,
):
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal['ipc', 'ipc_snapshot']] = None
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
class LoRAConfig:
"""LoRA Config"""
pass
class KVCacheConfig:
"""KV Cache Config"""
cache_quant_dtype: str = "none"
@@ -430,6 +434,7 @@ class DecodingConfig:
"""
Configuration for decoding
"""
def __init__(
self,
args,
@@ -439,26 +444,24 @@ class DecodingConfig:
if hasattr(self, key):
setattr(self, key, value)
@dataclass
class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
model_config: ModelConfig = field(default=None, init=True) # type: ignore
parallel_config: ParallelConfig = field(default=None, init=True)
speculative_config: SpeculativeConfig = field(default=None,
init=True) # type: ignore
device_config: DeviceConfig = field(default=None,
init=True) # type: ignore
speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore
device_config: DeviceConfig = field(default=None, init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True)
quant_config: Optional[QuantConfigBase] = None
graph_opt_config: Optional[GraphOptimizationConfig] = None
decoding_config: DecodingConfig = field(default=None,
init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None,
init=True) # type: ignore
decoding_config: DecodingConfig = field(default=None, init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None, init=True) # type: ignore
def __post_init__(self):
# Initialize cuda graph capture list
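The `batch_size_to_captured_size` computation reformatted above is easiest to follow with concrete numbers. The sketch below re-implements just that mapping for a hypothetical capture list `[1, 2, 4, 8]` and shows which padded graph size each batch size lands on; the capture sizes are invented for the example.

```python
# Re-implementation of the padding map built in init_with_cudagrpah_size:
# capture sizes are sorted descending, every batch size strictly between two
# capture sizes pads up to the larger one, and exact sizes map to themselves.
def build_padding_map(cudagraph_capture_sizes):
    sizes = sorted(set(cudagraph_capture_sizes), reverse=True)
    max_capture_size = sizes[0] if sizes else 0
    batch_size_to_captured_size = {}
    for end, start in zip(sizes, sizes[1:] + [0]):
        for bs in range(start, end):
            batch_size_to_captured_size[bs] = start if bs == start else end
    batch_size_to_captured_size[max_capture_size] = max_capture_size
    return batch_size_to_captured_size


mapping = build_padding_map([1, 2, 4, 8])  # hypothetical capture list
assert mapping[3] == 4   # a batch of 3 runs on the graph captured for 4
assert mapping[5] == 8   # 5, 6 and 7 all pad up to 8
assert mapping[8] == 8   # exact capture sizes map to themselves
print(mapping)
```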

View File

@@ -22,8 +22,6 @@ model_name_or_path = "./models/llama-7b"
# sampling hyperparameters
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
output = llm.generate(prompts="who are you",
use_tqdm=True,
sampling_params=sampling_params)
output = llm.generate(prompts="who are you", use_tqdm=True, sampling_params=sampling_params)
print(output)

View File

@@ -14,19 +14,15 @@
# limitations under the License.
"""
import time
import os
import multiprocessing
import os
import time
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.engine.sampling_params import SamplingParams
model_name_or_path = "baidu/ERNIE-4.5-21B-A3B-Paddle"
def start_decode(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["FD_LOG_DIR"] = "log_decode"
@@ -36,14 +32,15 @@ def start_decode(model_name_or_path):
splitwise_role="decode",
engine_worker_queue_port=6678,
innode_prefill_ports=[6676],
cache_queue_port=55668
cache_queue_port=55668,
)
return llm_decode
def start_prefill(model_name_or_path):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["FD_LOG_DIR"] = "log_prefill"
llm_prefill = LLM(
LLM(
model=model_name_or_path,
tensor_parallel_size=1,
splitwise_role="prefill",
@@ -53,16 +50,14 @@ def start_prefill(model_name_or_path):
def main():
prefill = multiprocessing.Process(
target=start_prefill,
args=(model_name_or_path,)).start()
prefill = multiprocessing.Process(target=start_prefill, args=(model_name_or_path,)).start()
time.sleep(10)
llm_decode = start_decode(model_name_or_path)
output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True)
print(output)
decode.join()
prefill.join()
if __name__ == "__main__":
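A note on the demo above: `multiprocessing.Process(...).start()` returns `None`, so the trailing `prefill.join()` cannot work, and `decode` is never defined. A corrected sketch of `main()`, reusing `start_prefill`, `start_decode` and `model_name_or_path` from this same demo file, might look like the following; the ten-second sleep is kept from the original as a crude wait for the prefill instance.

```python
# Hedged fix-up of the demo's main(): keep the Process handle so it can be
# joined later, and drop the undefined decode.join().
import multiprocessing
import time


def main():
    prefill = multiprocessing.Process(target=start_prefill, args=(model_name_or_path,))
    prefill.start()  # start() returns None, so the handle must be kept separately
    time.sleep(10)   # crude wait for the prefill instance to come up, as in the demo
    llm_decode = start_decode(model_name_or_path)
    output = llm_decode.generate(prompts=["who are you", "what can you do"], use_tqdm=True)
    print(output)
    prefill.join()


if __name__ == "__main__":
    main()
```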

View File

@@ -14,7 +14,6 @@
# limitations under the License.
"""
import openai
ip = "0.0.0.0"
@@ -42,7 +41,7 @@ response = client.completions.create(
)
for chunk in response:
print(chunk.choices[0].text, end='')
print(chunk.choices[0].text, end="")
print("\n")
# Chat completion
@@ -78,5 +77,5 @@ response = client.chat.completions.create(
for chunk in response:
if chunk.choices[0].delta is not None:
print(chunk.choices[0].delta.content, end='')
print(chunk.choices[0].delta.content, end="")
print("\n")

View File

@@ -14,14 +14,12 @@
# limitations under the License.
"""
import openai
print("hello")
ip = "0.0.0.0"
service_http_port = "9809"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1",
api_key="EMPTY_API_KEY")
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")
print("world")
# Non-streaming chat
@@ -30,23 +28,21 @@ response = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are a helpful AI assistant."
"content": "You are a helpful AI assistant.",
}, # the system message is optional
{
"role":
"user",
"content": [{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url":
"https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
"detail": "high"
}
}, {
"type": "text",
"text": "请描述图片内容"
}]
}
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
"detail": "high",
},
},
{"type": "text", "text": "请描述图片内容"},
],
},
],
temperature=1,
max_tokens=53,
@@ -60,30 +56,25 @@ response = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are a helpful AI assistant."
"content": "You are a helpful AI assistant.",
}, # the system message is optional
{
"role": "user",
"content": "List 3 countries and their capitals."
},
{"role": "user", "content": "List 3 countries and their capitals."},
{
"role": "assistant",
"content": "China(Beijing), France(Paris), Australia(Canberra)."
"content": "China(Beijing), France(Paris), Australia(Canberra).",
},
{
"role":
"user",
"content": [{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url":
"https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
"detail": "high"
}
}, {
"type": "text",
"text": "请描述图片内容"
}]
"url": "https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0",
"detail": "high",
},
},
{"type": "text", "text": "请描述图片内容"},
],
},
],
temperature=1,
@@ -94,5 +85,5 @@ for chunk in response:
if chunk.choices[0].delta is not None:
# print(chunk.choices[0].delta, end='')
# print("\n")
print(chunk.choices[0].delta.content, end='')
print(chunk.choices[0].delta.content, end="")
print(response)

View File

@@ -17,21 +17,28 @@
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
_TP_AR = None
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
global _TP_AR
if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda():
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
try:
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
def tensor_model_parallel_all_reduce(
input_: paddle.Tensor,
) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
global _TP_AR
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
@@ -42,5 +49,6 @@ try:
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
except:
tensor_model_parallel_all_reduce = None

View File

@@ -41,7 +41,7 @@ def find_loaded_library(lib_name) -> Optional[str]:
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
"""
found = False
with open("/proc/self/maps") as f:
for line in f:
@@ -73,18 +73,40 @@ class CudaRTLibrary:
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
Function(
"cudaMalloc",
cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
),
# cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function(
"cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint]
"cudaMemset",
cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t],
),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind )
Function(
"cudaMemcpy",
cudaError_t,
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr )
Function(
"cudaIpcGetMemHandle",
cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags )
Function(
"cudaIpcOpenMemHandle",
cudaError_t,
[
ctypes.POINTER(ctypes.c_void_p),
cudaIpcMemHandle_t,
ctypes.c_uint,
],
),
]
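Each `Function(...)` entry above follows the same ctypes recipe: look the symbol up, set `restype`, set `argtypes`, then call through the binding. The sketch below applies that recipe to a libc function so it runs without CUDA installed; `CudaRTLibrary` does the same thing against `libcudart` for the calls listed above.

```python
# Same (name, restype, argtypes) binding pattern as CudaRTLibrary, demonstrated
# on libc's strlen so it can run on a Linux machine without CUDA installed.
import ctypes
from dataclasses import dataclass, field


@dataclass
class Function:
    name: str
    restype: object
    argtypes: list = field(default_factory=list)


functions = [
    # size_t strlen(const char* s) -- stands in for the cudart entries above
    Function("strlen", ctypes.c_size_t, [ctypes.c_char_p]),
]

libc = ctypes.CDLL(None)  # bind against symbols already loaded into this process
funcs = {}
for f in functions:
    cfunc = getattr(libc, f.name)
    cfunc.restype = f.restype
    cfunc.argtypes = f.argtypes
    funcs[f.name] = cfunc

assert funcs["strlen"](b"cudaMalloc") == len(b"cudaMalloc")
print(f"bound {len(funcs)} function(s) via ctypes")
```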

View File

@@ -13,26 +13,26 @@
# limitations under the License.
from contextlib import contextmanager
import atexit
import ctypes
from contextlib import contextmanager
from typing import List, Optional
import paddle
import paddle.distributed as dist
from paddle.distributed.communication.group import Group
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
from fastdeploy.model_executor.ops.gpu import (
all_reduce,
dispose,
get_graph_buffer_ipc_meta,
init_custom_all_reduce,
meta_size,
register_buffer,
get_graph_buffer_ipc_meta,
register_graph_buffers,
)
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
try:
meta_size()
custom_ar = True
@@ -147,7 +147,12 @@ class CustomAllreduce:
return inp_size < self.max_size
return False
def all_reduce(self, inp: paddle.Tensor, out: paddle.Tensor = None, registered: bool = False):
def all_reduce(
self,
inp: paddle.Tensor,
out: paddle.Tensor = None,
registered: bool = False,
):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
@@ -179,16 +184,12 @@ class CustomAllreduce:
def register_graph_buffers(self):
handle, offset = get_graph_buffer_ipc_meta(self._ptr)
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore

View File

@@ -13,15 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
from fastdeploy.engine.config import (CacheConfig, Config,
GraphOptimizationConfig, ModelConfig,
ParallelConfig, SpeculativeConfig,
TaskOption)
from fastdeploy.engine.config import (
CacheConfig,
Config,
GraphOptimizationConfig,
ModelConfig,
ParallelConfig,
SpeculativeConfig,
TaskOption,
)
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser
@@ -323,365 +329,429 @@ class EngineArgs:
"""
# Model parameters group
model_group = parser.add_argument_group("Model Configuration")
model_group.add_argument("--model",
model_group.add_argument(
"--model",
type=str,
default=EngineArgs.model,
help="Model name or path to be used.")
model_group.add_argument("--model-config-name",
help="Model name or path to be used.",
)
model_group.add_argument(
"--model-config-name",
type=nullable_str,
default=EngineArgs.model_config_name,
help="The model configuration file name.")
help="The model configuration file name.",
)
model_group.add_argument(
"--tokenizer",
type=nullable_str,
default=EngineArgs.tokenizer,
help=
"Tokenizer name or path (defaults to model path if not specified)."
help="Tokenizer name or path (defaults to model path if not specified).",
)
model_group.add_argument(
"--max-model-len",
type=int,
default=EngineArgs.max_model_len,
help="Maximum context length supported by the model.")
help="Maximum context length supported by the model.",
)
model_group.add_argument(
"--block-size",
type=int,
default=EngineArgs.block_size,
help="Number of tokens processed in one block.")
model_group.add_argument("--task",
help="Number of tokens processed in one block.",
)
model_group.add_argument(
"--task",
type=str,
default=EngineArgs.task,
help="Task to be executed by the model.")
help="Task to be executed by the model.",
)
model_group.add_argument(
"--use-warmup",
type=int,
default=EngineArgs.use_warmup,
help="Flag to indicate whether to use warm-up before inference.")
help="Flag to indicate whether to use warm-up before inference.",
)
model_group.add_argument(
"--limit-mm-per-prompt",
default=EngineArgs.limit_mm_per_prompt,
type=json.loads,
help="Limitation of numbers of multi-modal data.")
help="Limitation of numbers of multi-modal data.",
)
model_group.add_argument(
"--mm-processor-kwargs",
default=EngineArgs.mm_processor_kwargs,
type=json.loads,
help="Additional keyword arguments for the multi-modal processor.")
model_group.add_argument("--enable-mm",
action='store_true',
help="Additional keyword arguments for the multi-modal processor.",
)
model_group.add_argument(
"--enable-mm",
action="store_true",
default=EngineArgs.enable_mm,
help="Flag to enable multi-modal model.")
model_group.add_argument("--reasoning-parser",
help="Flag to enable multi-modal model.",
)
model_group.add_argument(
"--reasoning-parser",
type=str,
default=EngineArgs.reasoning_parser,
help="Flag specifies the reasoning parser to use for extracting "\
"reasoning content from the model output")
help="Flag specifies the reasoning parser to use for extracting "
"reasoning content from the model output",
)
model_group.add_argument(
"--speculative-config",
type=json.loads,
default=EngineArgs.speculative_config,
help="Configuration for speculative execution.")
help="Configuration for speculative execution.",
)
model_group.add_argument(
"--dynamic-load-weight",
action='store_true',
action="store_true",
default=EngineArgs.dynamic_load_weight,
help="Flag to indicate whether to load weight dynamically.")
help="Flag to indicate whether to load weight dynamically.",
)
model_group.add_argument(
"--load-strategy",
type=str,
default=EngineArgs.load_strategy,
help="Flag to dynamic load strategy.")
model_group.add_argument("--engine-worker-queue-port",
help="Flag to dynamic load strategy.",
)
model_group.add_argument(
"--engine-worker-queue-port",
type=int,
default=EngineArgs.engine_worker_queue_port,
help="port for engine worker queue")
model_group.add_argument("--quantization",
help="port for engine worker queue",
)
model_group.add_argument(
"--quantization",
type=str,
default=EngineArgs.quantization,
help="Quantization name for the model, currentlly support " \
"'wint8', 'wint4'," \
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
model_group.add_argument("--use-cudagraph",
action='store_true',
help="Quantization name for the model, currentlly support "
"'wint8', 'wint4',"
"default is None. The priority of this configuration "
"is lower than that of the config file. "
"More complex quantization methods need to be configured via the config file.",
)
model_group.add_argument(
"--use-cudagraph",
action="store_true",
default=EngineArgs.use_cudagraph,
help="Flags to enable cuda graph.")
model_group.add_argument("--graph-optimization-config",
help="Flags to enable cuda graph.",
)
model_group.add_argument(
"--graph-optimization-config",
type=json.loads,
default=EngineArgs.graph_optimization_config,
help="")
model_group.add_argument("--guided-decoding-backend",
help="",
)
model_group.add_argument(
"--guided-decoding-backend",
type=str,
default=EngineArgs.guided_decoding_backend,
help="Guided Decoding Backend")
help="Guided Decoding Backend",
)
model_group.add_argument(
"--guided-decoding-disable-any-whitespace",
type=str,
default=EngineArgs.guided_decoding_disable_any_whitespace,
help=
"Disabled any whitespaces when using guided decoding backend XGrammar."
help="Disabled any whitespaces when using guided decoding backend XGrammar.",
)
model_group.add_argument("--enable-logprob",
model_group.add_argument(
"--enable-logprob",
action="store_true",
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities."
help="Enable output of token-level log probabilities.",
)
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
parallel_group.add_argument("--tensor-parallel-size",
parallel_group.add_argument(
"--tensor-parallel-size",
"-tp",
type=int,
default=EngineArgs.tensor_parallel_size,
help="Degree of tensor parallelism.")
parallel_group.add_argument("--enable-custom-all-reduce",
action='store_true',
help="Degree of tensor parallelism.",
)
parallel_group.add_argument(
"--enable-custom-all-reduce",
action="store_true",
default=EngineArgs.enable_custom_all_reduce,
help="Flag to enable custom all-reduce.")
help="Flag to enable custom all-reduce.",
)
parallel_group.add_argument(
"--max-num-seqs",
type=int,
default=EngineArgs.max_num_seqs,
help="Maximum number of sequences per iteration.")
help="Maximum number of sequences per iteration.",
)
parallel_group.add_argument(
"--num-gpu-blocks-override",
type=int,
default=EngineArgs.num_gpu_blocks_override,
help="Override for the number of GPU blocks.")
help="Override for the number of GPU blocks.",
)
parallel_group.add_argument(
"--max-num-batched-tokens",
type=int,
default=EngineArgs.max_num_batched_tokens,
help="Maximum number of tokens to batch together.")
help="Maximum number of tokens to batch together.",
)
parallel_group.add_argument(
"--gpu-memory-utilization",
type=float,
default=EngineArgs.gpu_memory_utilization,
help="Fraction of GPU memory to be utilized.")
help="Fraction of GPU memory to be utilized.",
)
parallel_group.add_argument("--data-parallel-size",
parallel_group.add_argument(
"--data-parallel-size",
type=int,
default=EngineArgs.data_parallel_size,
help="Degree of data parallelism.")
parallel_group.add_argument("--enable-expert-parallel",
action='store_true',
help="Degree of data parallelism.",
)
parallel_group.add_argument(
"--enable-expert-parallel",
action="store_true",
default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.")
help="Enable expert parallelism.",
)
# CacheConfig parameters group
cache_group = parser.add_argument_group("Cache Configuration")
cache_group.add_argument("--kv-cache-ratio",
cache_group.add_argument(
"--kv-cache-ratio",
type=float,
default=EngineArgs.kv_cache_ratio,
help="Ratio of tokens to process in a block.")
help="Ratio of tokens to process in a block.",
)
cache_group.add_argument(
"--swap-space",
type=float,
default=EngineArgs.swap_space,
help="The amount of CPU memory to offload to.")
help="The amount of CPU memory to offload to.",
)
cache_group.add_argument("--cache-queue-port",
cache_group.add_argument(
"--cache-queue-port",
type=int,
default=EngineArgs.cache_queue_port,
help="port for cache queue")
cache_group.add_argument("--static-decode-blocks",
help="port for cache queue",
)
cache_group.add_argument(
"--static-decode-blocks",
type=int,
default=EngineArgs.static_decode_blocks,
help="Static decoding blocks num.")
help="Static decoding blocks num.",
)
# Cluster system parameters group
system_group = parser.add_argument_group("System Configuration")
system_group.add_argument(
"--dist-init-ip",
default=EngineArgs.dist_init_ip,
help=
"IP addresses of master node.")
help="IP addresses of master node.",
)
system_group.add_argument(
"--nnodes",
type=int,
default=EngineArgs.nnodes,
help=
"The number of all nodes.")
help="The number of all nodes.",
)
system_group.add_argument(
"--node-rank",
type=int,
default=EngineArgs.node_rank,
help=
"node rank id (range [0, nnodes)).")
help="node rank id (range [0, nnodes)).",
)
# Performance tuning parameters group
perf_group = parser.add_argument_group("Performance Tuning")
perf_group.add_argument("--enable-prefix-caching",
action='store_true',
perf_group.add_argument(
"--enable-prefix-caching",
action="store_true",
default=EngineArgs.enable_prefix_caching,
help="Flag to enable prefix caching.")
help="Flag to enable prefix caching.",
)
perf_group.add_argument("--splitwise-role",
perf_group.add_argument(
"--splitwise-role",
type=str,
default=EngineArgs.splitwise_role,
help="Role of splitwise. Default is \
'mixed'. (prefill, decode, mixed)")
'mixed'. (prefill, decode, mixed)",
)
perf_group.add_argument("--innode-prefill-ports",
perf_group.add_argument(
"--innode-prefill-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.innode_prefill_ports,
help="port for innode prefill")
help="port for innode prefill",
)
perf_group.add_argument("--enable-chunked-prefill",
action='store_true',
perf_group.add_argument(
"--enable-chunked-prefill",
action="store_true",
default=EngineArgs.enable_chunked_prefill,
help="Flag to enable chunked prefill.")
perf_group.add_argument("--max-num-partial-prefills",
help="Flag to enable chunked prefill.",
)
perf_group.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, Maximum number \
of concurrent partial prefill requests.")
of concurrent partial prefill requests.",
)
perf_group.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help=
("For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold"
"that will be prefilled concurrently."))
help=(
"For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold"
"that will be prefilled concurrently."
),
)
perf_group.add_argument(
"--long-prefill-token-threshold",
type=int,
default=EngineArgs.long_prefill_token_threshold,
help=("For chunked prefill, the threshold number of"
" tokens for a prompt to be considered long."))
help=("For chunked prefill, the threshold number of" " tokens for a prompt to be considered long."),
)
perf_group.add_argument(
"--cache-transfer-protocol",
type=str,
default=EngineArgs.cache_transfer_protocol,
help="support protocol list, comma separated, default is ipc")
help="support protocol list, comma separated, default is ipc",
)
perf_group.add_argument("--pd-comm-port",
perf_group.add_argument(
"--pd-comm-port",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.pd_comm_port,
help="port for splitwise communication.")
help="port for splitwise communication.",
)
perf_group.add_argument("--rdma-comm-ports",
perf_group.add_argument(
"--rdma-comm-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.rdma_comm_ports,
help="ports for rdma communication.")
help="ports for rdma communication.",
)
# Scheduler parameters group
scheduler_group = parser.add_argument_group("Scheduler")
scheduler_group.add_argument(
"--scheduler-name",
default=EngineArgs.scheduler_name,
help=
f"Scheduler name to be used. Default is {EngineArgs.scheduler_name}. (local,global)"
help=f"Scheduler name to be used. Default is {EngineArgs.scheduler_name}. (local,global)",
)
scheduler_group.add_argument(
"--scheduler-max-size",
type=int,
default=EngineArgs.scheduler_max_size,
help=
f"Size of scheduler. Default is {EngineArgs.scheduler_max_size}. (Local)"
help=f"Size of scheduler. Default is {EngineArgs.scheduler_max_size}. (Local)",
)
scheduler_group.add_argument(
"--scheduler-ttl",
type=int,
default=EngineArgs.scheduler_ttl,
help=
f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)"
help=f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)",
)
scheduler_group.add_argument(
"--scheduler-host",
default=EngineArgs.scheduler_host,
help=
f"Host address of redis. Default is {EngineArgs.scheduler_host}. (global)"
help=f"Host address of redis. Default is {EngineArgs.scheduler_host}. (global)",
)
scheduler_group.add_argument(
"--scheduler-port",
type=int,
default=EngineArgs.scheduler_port,
help=
f"Port of redis. Default is {EngineArgs.scheduler_port}. (global)")
help=f"Port of redis. Default is {EngineArgs.scheduler_port}. (global)",
)
scheduler_group.add_argument(
"--scheduler-db",
type=int,
default=EngineArgs.scheduler_db,
help=f"DB of redis. Default is {EngineArgs.scheduler_db}. (global)"
help=f"DB of redis. Default is {EngineArgs.scheduler_db}. (global)",
)
scheduler_group.add_argument(
"--scheduler-password",
default=EngineArgs.scheduler_password,
help=
f"Password of redis. Default is {EngineArgs.scheduler_password}. (global)"
help=f"Password of redis. Default is {EngineArgs.scheduler_password}. (global)",
)
scheduler_group.add_argument(
"--scheduler-topic",
default=EngineArgs.scheduler_topic,
help=
f"Topic of scheduler. Defaule is {EngineArgs.scheduler_topic}. (global)"
help=f"Topic of scheduler. Defaule is {EngineArgs.scheduler_topic}. (global)",
)
scheduler_group.add_argument(
"--scheduler-min-load-score",
type=float,
default=EngineArgs.scheduler_min_load_score,
help=
f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)"
help=f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)",
)
scheduler_group.add_argument(
"--scheduler-load-shards-num",
type=int,
default=EngineArgs.scheduler_load_shards_num,
help=("Number of shards for load balancing table. Default is "
f"{EngineArgs.scheduler_load_shards_num} (global)"))
help=(
"Number of shards for load balancing table. Default is "
f"{EngineArgs.scheduler_load_shards_num} (global)"
),
)
scheduler_group.add_argument(
"--scheduler-sync-period",
type=int,
default=EngineArgs.scheduler_sync_period,
help=f"SplitWise Use, node load sync period, "
f"Default is {EngineArgs.scheduler_sync_period}ms. (global)")
f"Default is {EngineArgs.scheduler_sync_period}ms. (global)",
)
scheduler_group.add_argument(
"--scheduler-expire-period",
type=int,
default=EngineArgs.scheduler_expire_period,
help=f"SplitWise Use, node will not be scheduled after "
f"expire-period ms not sync load, Default is "
f"{EngineArgs.scheduler_expire_period}ms. (global)")
f"{EngineArgs.scheduler_expire_period}ms. (global)",
)
scheduler_group.add_argument(
"--scheduler-release-load-expire-period",
type=int,
default=EngineArgs.scheduler_release_load_expire_period,
help=f"SplitWise Use, scheduler will release req load after "
f"expire period(s). Default is "
f"{EngineArgs.scheduler_release_load_expire_period}. (global)")
f"{EngineArgs.scheduler_release_load_expire_period}. (global)",
)
scheduler_group.add_argument(
"--scheduler-reader-parallel",
type=int,
default=EngineArgs.scheduler_reader_parallel,
help=f"SplitWise Use, Results Reader Sync Parallel, "
f"Default is {EngineArgs.scheduler_reader_parallel}. (global)")
f"Default is {EngineArgs.scheduler_reader_parallel}. (global)",
)
scheduler_group.add_argument(
"--scheduler-writer-parallel",
type=int,
default=EngineArgs.scheduler_writer_parallel,
help=f"SplitWise Use, Results Writer Sync Parallel, "
f"Default is {EngineArgs.scheduler_writer_parallel}. (global)")
f"Default is {EngineArgs.scheduler_writer_parallel}. (global)",
)
scheduler_group.add_argument(
"--scheduler-reader-batch-size",
type=int,
default=EngineArgs.scheduler_reader_batch_size,
help=f"SplitWise Use, Results Reader Batch Size, "
f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)")
f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)",
)
scheduler_group.add_argument(
"--scheduler-writer-batch-size",
type=int,
default=EngineArgs.scheduler_writer_batch_size,
help=f"SplitWise Use, Results Writer Batch Size, "
f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)")
f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)",
)
return parser
@@ -690,21 +760,19 @@ class EngineArgs:
"""
Create an instance of EngineArgs from command line arguments.
"""
return cls(
**{
field.name: getattr(args, field.name)
for field in dataclass_fields(cls)
})
return cls(**{field.name: getattr(args, field.name) for field in dataclass_fields(cls)})
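Taken together with the argument groups above, the intended command-line wiring looks roughly like the sketch below; the import path, the add_cli_args/from_cli_args/create_engine_config method names, and the sample flags are assumptions for illustration.

from fastdeploy.engine.args_utils import EngineArgs  # import path assumed
from fastdeploy.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)  # registers the argument groups shown above
args = parser.parse_args(["--model", "./my-model", "--tensor-parallel-size", "2"])

engine_args = EngineArgs.from_cli_args(args)  # dataclass fields pulled off argparse
config = engine_args.create_engine_config()   # builds the full Config object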
def create_model_config(self) -> ModelConfig:
"""
Create and return a ModelConfig object based on the current settings.
"""
return ModelConfig(model_name_or_path=self.model,
return ModelConfig(
model_name_or_path=self.model,
config_json_file=self.model_config_name,
quantization=self.quantization,
dynamic_load_weight=self.dynamic_load_weight,
load_strategy=self.load_strategy)
load_strategy=self.load_strategy,
)
def create_cache_config(self, model_cfg) -> CacheConfig:
"""
@@ -728,8 +796,7 @@ class EngineArgs:
)
def create_speculative_config(self) -> SpeculativeConfig:
"""
"""
""" """
if self.speculative_config is not None:
return SpeculativeConfig(**self.speculative_config)
else:
@@ -742,9 +809,11 @@ class EngineArgs:
prefix = "scheduler_"
prefix_len = len(prefix)
extra_params = [
"max_model_len", "enable_chunked_prefill",
"max_num_partial_prefills", "max_long_partial_prefills",
"long_prefill_token_threshold"
"max_model_len",
"enable_chunked_prefill",
"max_num_partial_prefills",
"max_long_partial_prefills",
"long_prefill_token_threshold",
]
all = asdict(self)
@@ -765,7 +834,7 @@ class EngineArgs:
tensor_parallel_size=self.tensor_parallel_size,
enable_expert_parallel=self.enable_expert_parallel,
data_parallel_size=self.data_parallel_size,
enable_custom_all_reduce=self.enable_custom_all_reduce
enable_custom_all_reduce=self.enable_custom_all_reduce,
)
def create_graph_optimization_config(self) -> GraphOptimizationConfig:
@@ -782,8 +851,7 @@ class EngineArgs:
Create and return a Config object based on the current settings.
"""
model_cfg = self.create_model_config()
if not model_cfg.is_unified_ckpt and hasattr(model_cfg,
'tensor_parallel_size'):
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
self.tensor_parallel_size = model_cfg.tensor_parallel_size
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
@@ -795,11 +863,11 @@ class EngineArgs:
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
assert not (self.use_cudagraph and self.enable_prefix_caching), \
"Prefix caching cannot be used with CUDA graph"
assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"
assert not (self.tensor_parallel_size<=1 and self.enable_custom_all_reduce), \
"enable_custom_all_reduce must be used with tensor_parallel_size>1"
assert not (
self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
return Config(
model_name_or_path=self.model,

View File

@@ -23,8 +23,14 @@ from typing import Any, Dict, List, Literal, Optional
from fastdeploy import envs
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
is_port_available, get_random_port, llm_logger)
from fastdeploy.utils import (
ceil_div,
check_unified_ckpt,
get_host_ip,
get_random_port,
is_port_available,
llm_logger,
)
TaskOption = Literal["generate"]
@@ -39,13 +45,15 @@ class ModelConfig:
model_name_or_path (str): Name or path of the model.
"""
def __init__(self,
def __init__(
self,
model_name_or_path: str,
config_json_file: str = "config.json",
dynamic_load_weight: bool = False,
load_strategy: str = "ipc_snapshot",
quantization: str = None,
download_dir: Optional[str] = None):
download_dir: Optional[str] = None,
):
"""
Initialize the ModelConfig class.
@@ -64,11 +72,9 @@ class ModelConfig:
if os.path.isfile(model_name_or_path):
try:
from paddleformers.transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name_or_path)
config_dict = {
k: v
for k, v in vars(config).items() if not k.startswith('_')
}
config_dict = {k: v for k, v in vars(config).items() if not k.startswith("_")}
for key, value in config_dict.items():
setattr(self, key, value)
except Exception:
@@ -115,8 +121,7 @@ class ModelConfig:
if not hasattr(self, "mla_use_absorb"):
self.mla_use_absorb = False
if not hasattr(self, "head_dim"):
assert hasattr(self, "hidden_size") and hasattr(
self, "num_attention_heads")
assert hasattr(self, "hidden_size") and hasattr(self, "num_attention_heads")
self.head_dim = self.hidden_size // self.num_attention_heads
def read_from_env(self):
@@ -132,11 +137,9 @@ class ModelConfig:
if not hasattr(self, key.lower()):
if os.getenv(key, None):
value = eval(os.getenv(key))
llm_logger.info(
f"Get parameter `{key}` = {value} from environment.")
llm_logger.info(f"Get parameter `{key}` = {value} from environment.")
else:
llm_logger.info(
f"Parameter `{key}` will use default value {value}.")
llm_logger.info(f"Parameter `{key}` will use default value {value}.")
setattr(self, key.lower(), value)
reset_config_value("COMPRESSION_RATIO", 1.0)
@@ -153,8 +156,7 @@ class ModelConfig:
llm_logger.info("Model Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
llm_logger.info("=============================================================")
class CacheConfig:
@@ -211,8 +213,7 @@ class CacheConfig:
self.enc_dec_block_num = enc_dec_block_num
self.cache_dtype = cache_dtype
if hasattr(model_cfg, "quantization_config"):
self.cache_dtype = model_cfg.quantization_config.get(
"kv_cache_quant_type", cache_dtype)
self.cache_dtype = model_cfg.quantization_config.get("kv_cache_quant_type", cache_dtype)
self.enable_chunked_prefill = enable_chunked_prefill
self.rdma_comm_ports = rdma_comm_ports
@@ -220,7 +221,7 @@ class CacheConfig:
self.pd_comm_port = pd_comm_port
if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str):
self.rdma_comm_ports = rdma_comm_ports.split(',')
self.rdma_comm_ports = rdma_comm_ports.split(",")
if pd_comm_port is not None and isinstance(pd_comm_port, str):
self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")]
@@ -236,41 +237,39 @@ class CacheConfig:
self.cache_queue_port = cache_queue_port
self.swap_space = swap_space
if (hasattr(self.model_cfg, "num_key_value_heads")
if (
hasattr(self.model_cfg, "num_key_value_heads")
and hasattr(self.model_cfg, "num_key_value_heads")
and self.model_cfg.num_key_value_heads is not None
and int(self.model_cfg.num_key_value_heads) > 0):
and int(self.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(self.model_cfg.num_key_value_heads)
else:
kv_num_head = self.model_cfg.num_attention_heads
self.model_cfg.kv_num_head = kv_num_head
# TODO check name
if "int4" in self.cache_dtype.lower(
) or "float4" in self.cache_dtype.lower():
if "int4" in self.cache_dtype.lower() or "float4" in self.cache_dtype.lower():
byte_size = 0.5
self.cache_dtype = "uint8"
elif "int8" in self.cache_dtype.lower(
) or "float8" in self.cache_dtype.lower():
elif "int8" in self.cache_dtype.lower() or "float8" in self.cache_dtype.lower():
self.cache_dtype = "uint8"
byte_size = 1
else:
byte_size = 2
self.each_token_cache_space = int(
self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim *
byte_size)
self.bytes_per_block = int(self.each_token_cache_space *
self.block_size)
self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * byte_size
)
self.bytes_per_block = int(self.each_token_cache_space * self.block_size)
self.bytes_per_layer_per_block = int(
self.block_size * self.model_cfg.kv_num_head *
self.model_cfg.head_dim // tensor_parallel_size * byte_size)
self.block_size * self.model_cfg.kv_num_head * self.model_cfg.head_dim // tensor_parallel_size * byte_size
)
if self.swap_space is None:
self.num_cpu_blocks = 0
else:
self.num_cpu_blocks = int(self.swap_space * 1024**3 /
self.bytes_per_block)
self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
self._verify_args()
def metrics_info(self):
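A rough worked example of the cache-sizing arithmetic above; every number below is made up purely for illustration.

# Assumed model shape: 32 layers, 8 KV heads, head_dim 128, 2-byte cache dtype,
# block_size 64, tensor_parallel_size 1, swap_space 4 GiB.
num_layers, kv_num_head, head_dim, byte_size = 32, 8, 128, 2
block_size, tensor_parallel_size, swap_space = 64, 1, 4

each_token_cache_space = num_layers * kv_num_head * head_dim * byte_size  # 65,536 bytes per token
bytes_per_block = each_token_cache_space * block_size                     # 4 MiB per block
bytes_per_layer_per_block = block_size * kv_num_head * head_dim // tensor_parallel_size * byte_size  # 128 KiB
num_cpu_blocks = int(swap_space * 1024**3 / bytes_per_block)              # 1024 blocks fit in 4 GiB

print(each_token_cache_space, bytes_per_block, bytes_per_layer_per_block, num_cpu_blocks)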
@@ -279,12 +278,9 @@ class CacheConfig:
def _verify_args(self):
if self.gpu_memory_utilization > 1.0:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
if self.kv_cache_ratio > 1.0:
raise ValueError("KV cache ratio must be less than 1.0. Got "
f"{self.kv_cache_ratio}.")
raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
def postprocess(self, num_total_tokens, number_of_tasks):
"""
@@ -293,27 +289,24 @@ class CacheConfig:
self.dec_token_num = self.enc_dec_block_num * self.block_size
if self.num_gpu_blocks_override is not None:
self.total_block_num = self.num_gpu_blocks_override
self.prefill_kvcache_block_num = int(self.total_block_num *
self.kv_cache_ratio)
self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
else:
length = num_total_tokens // number_of_tasks
block_num = (length + self.block_size - 1 +
self.dec_token_num) // self.block_size
block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
self.total_block_num = block_num * number_of_tasks
self.prefill_kvcache_block_num = self.total_block_num
llm_logger.info(
f"Doing profile, the total_block_num:{self.total_block_num}")
llm_logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
def reset(self, num_gpu_blocks):
"""
reset gpu block number
"""
self.total_block_num = num_gpu_blocks
self.prefill_kvcache_block_num = int(self.total_block_num *
self.kv_cache_ratio)
self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
llm_logger.info(
(f"Reset block num, the total_block_num:{self.total_block_num},"
f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))
f"Reset block num, the total_block_num:{self.total_block_num},"
f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
)
def print(self):
"""
@@ -323,8 +316,7 @@ class CacheConfig:
llm_logger.info("Cache Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
llm_logger.info("=============================================================")
class SpeculativeConfig:
@@ -340,14 +332,16 @@ class SpeculativeConfig:
benchmark_mode (bool): Whether to use benchmark mode.
"""
def __init__(self,
def __init__(
self,
method: Optional[str] = None,
num_speculative_tokens: Optional[int] = 1,
model: Optional[str] = None,
quantization: Optional[str] = "WINT8",
max_model_len: Optional[int] = None,
benchmark_mode: bool = False,
**kwargs):
**kwargs,
):
self.model_name_or_path = model
self.method = method
self.num_speculative_tokens = num_speculative_tokens
@@ -381,8 +375,7 @@ class SpeculativeConfig:
self.config_path = os.path.join(self.model_name_or_path, "config.json")
if os.path.exists(self.config_path):
self.model_config = json.load(
open(self.config_path, 'r', encoding='utf-8'))
self.model_config = json.load(open(self.config_path, "r", encoding="utf-8"))
def reset(self):
"""
@@ -414,10 +407,7 @@ class SpeculativeConfig:
"""
Convert speculative_config to json string.
"""
return json.dumps({
key: value
for key, value in self.__dict__.items() if value is not None
})
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
def print(self):
"""
@@ -427,8 +417,7 @@ class SpeculativeConfig:
llm_logger.info("Speculative Decoding Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
llm_logger.info("=============================================================")
def __str__(self) -> str:
return self.to_json_string()
@@ -440,7 +429,7 @@ class GraphOptimizationConfig:
graph_opt_level: Optional[int] = 0,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
**kwargs,
):
"""
Graph Optimization Configuration class.
@@ -460,10 +449,7 @@ class GraphOptimizationConfig:
"""
Convert speculative_config to json string.
"""
return json.dumps({
key: value
for key, value in self.__dict__.items()
})
return json.dumps({key: value for key, value in self.__dict__.items()})
def __str__(self) -> str:
return self.to_json_string()
@@ -473,17 +459,25 @@ class GraphOptimizationConfig:
graph_opt_level: Optional[int] = None,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
**kwargs,
) -> None:
"""Check the legality of parameters passed in from the command line"""
if graph_opt_level is not None:
assert graph_opt_level in [0, 1, 2], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
assert graph_opt_level in [
0,
1,
2,
], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if use_cudagraph is not None:
assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must be bool."
if cudagraph_capture_sizes is not None:
assert type(cudagraph_capture_sizes) is list, "In graph optimization config, type of cudagraph_capture_sizes must is list."
assert len(cudagraph_capture_sizes) > 0, "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
assert (
type(cudagraph_capture_sizes) is list
), "In graph optimization config, type of cudagraph_capture_sizes must is list."
assert (
len(cudagraph_capture_sizes) > 0
), "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."
for key, value in kwargs.items():
raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
@@ -499,9 +493,12 @@ class GraphOptimizationConfig:
else:
# User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True:
raise ValueError("Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.")
raise ValueError(
"Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously."
)
argument = self.use_cudagraph
class ParallelConfig:
"""
Configuration for parallelism.
@@ -544,8 +541,7 @@ class ParallelConfig:
llm_logger.info("Parallel Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
llm_logger.info("=============================================================")
@dataclass
@@ -560,6 +556,7 @@ class CommitConfig:
cuda_version: CUDA version string
compiler_version: CXX compiler version string
"""
fastdeploy_commit: str = ""
paddle_version: str = ""
paddle_commit: str = ""
@@ -573,7 +570,7 @@ class CommitConfig:
def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
"""Internal method to load version info from file"""
try:
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
for line in f:
line = line.strip()
if line.startswith("fastdeploy GIT COMMIT ID:"):
@@ -589,7 +586,7 @@ class CommitConfig:
except FileNotFoundError:
llm_logger.info(f"Warning: Version file not found at {file_path}")
except Exception as e:
llm_logger.info(f"Warning: Could not read version file - {str(e)}")
llm_logger.info(f"Warning: Could not read version file - {e!s}")
def print(self):
"""
@@ -599,8 +596,7 @@ class CommitConfig:
llm_logger.info("Fasedeploy Commit Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
llm_logger.info("=============================================================")
class Config:
@@ -728,7 +724,6 @@ class Config:
self.disable_any_whitespace = disable_any_whitespace
self._str_to_list("innode_prefill_ports", int)
assert self.splitwise_role in ["mixed", "prefill", "decode"]
# TODO
@@ -739,19 +734,16 @@ class Config:
self.max_prefill_batch = 1 # TODO: Currently the multimodal prefill stage only supports a parallelism degree of 1; to be optimized.
# TODO(@wufeisheng): TP and EP need to be supported simultaneously.
assert (self.tensor_parallel_size == 1
and self.parallel_config.expert_parallel_size
>= 1) or (self.tensor_parallel_size >= 1
and self.parallel_config.expert_parallel_size
== 1), "TP and EP cannot be enabled at the same time"
assert (self.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
self.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
), "TP and EP cannot be enabled at the same time"
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if num_ranks > self.max_chips_per_node:
self.worker_num_per_node = self.max_chips_per_node
nnode = ceil_div(num_ranks, self.worker_num_per_node)
assert nnode == self.nnode, \
f"nnode: {nnode}, but got {self.nnode}"
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
else:
self.worker_num_per_node = num_ranks
@@ -772,13 +764,14 @@ class Config:
"""
calculate some parameters
"""
assert self.device_ids.split(',').__len__() == self.worker_num_per_node, \
f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
assert (
self.device_ids.split(",").__len__() == self.worker_num_per_node
), f"invalid CUDA_VISIBLE_DEVICES, should be equal to {self.worker_num_per_node}"
assert self.worker_num_per_node % self.tensor_parallel_size == 0, \
f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
self.local_device_ids = self.device_ids.split(
',')[:self.tensor_parallel_size]
assert (
self.worker_num_per_node % self.tensor_parallel_size == 0
), f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by worker_num_per_node: {self.worker_num_per_node}"
self.local_device_ids = self.device_ids.split(",")[: self.tensor_parallel_size]
self.host_ip = get_host_ip()
@@ -788,6 +781,7 @@ class Config:
self.is_master = False
import paddle
self.paddle_commit_id = paddle.version.commit
if self.max_num_batched_tokens is None:
@@ -799,10 +793,8 @@ class Config:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.cache_config.postprocess(self.max_num_batched_tokens,
self.max_num_seqs)
self.cache_config.max_block_num_per_seq = int(
self.max_model_len // self.cache_config.block_size)
self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
if self.guided_decoding_backend == "auto":
if self.enable_mm:
@@ -814,30 +806,26 @@ class Config:
"""
check the legality of config
"""
assert (
self.max_num_seqs <= 256
), "The parameter `max_num_seqs` is not allowed to exceed 256, " "but now it's {}.".format(
self.max_num_seqs)
assert (
is_port_available('0.0.0.0', self.engine_worker_queue_port)
assert self.max_num_seqs <= 256, (
"The parameter `max_num_seqs` is not allowed to exceed 256, " f"but now it's {self.max_num_seqs}."
)
assert is_port_available(
"0.0.0.0", self.engine_worker_queue_port
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
assert (
self.max_chips_per_node >= self.tensor_parallel_size > 0
), f"tensor_parallel_size: {self.tensor_parallel_size} should be between 1 and {self.max_chips_per_node}"
assert (self.nnode >= 1), f"nnode: {self.nnode} should no less than 1"
assert (
self.max_model_len >= 16
), f"max_model_len: {self.max_model_len} should be larger than 16"
assert (
self.max_num_seqs
>= 1), f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
assert (
self.max_num_batched_tokens >= self.max_num_seqs
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
assert self.nnode >= 1, f"nnode: {self.nnode} should no less than 1"
assert self.max_model_len >= 16, f"max_model_len: {self.max_model_len} should be larger than 16"
assert self.max_num_seqs >= 1, f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
assert self.max_num_batched_tokens >= self.max_num_seqs, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
assert (self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs), \
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" \
)
assert self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger"
f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
)
assert (
self.max_num_partial_prefills >= 1
), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
@@ -845,31 +833,38 @@ class Config:
assert (
self.max_long_partial_prefills >= 1
), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
assert (self.max_long_partial_prefills <= self.max_num_partial_prefills), \
f"max_long_partial_prefills: {self.max_long_partial_prefills} should " \
assert self.max_long_partial_prefills <= self.max_num_partial_prefills, (
f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
)
if not self.cache_config.enable_chunked_prefill:
assert (
self.max_num_batched_tokens >= self.max_model_len
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
assert self.max_num_batched_tokens >= self.max_model_len, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_model_len: {self.max_model_len}"
)
else:
assert (
self.max_num_batched_tokens >= self.cache_config.block_size
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
assert self.max_num_batched_tokens >= self.cache_config.block_size, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to block_size: {self.cache_config.block_size}"
)
if self.max_num_partial_prefills > 1:
assert (self.cache_config.enable_chunked_prefill is True), \
"Chunked prefill must be enabled to set max_num_partial_prefills > 1"
assert (self.long_prefill_token_threshold < self.max_model_len), \
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"\
assert (
self.cache_config.enable_chunked_prefill is True
), "Chunked prefill must be enabled to set max_num_partial_prefills > 1"
assert self.long_prefill_token_threshold < self.max_model_len, (
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"
f" max_model_len: {self.max_model_len}"
)
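As a quick sanity check of the batching constraints asserted above, with made-up values:

# Illustrative values only; none of these are FastDeploy defaults.
max_model_len, max_num_seqs, block_size = 8192, 64, 64
enable_chunked_prefill = True
max_num_batched_tokens = 2048

assert max_num_seqs <= 256
assert max_num_batched_tokens >= max_num_seqs
assert max_num_batched_tokens <= max_model_len * max_num_seqs
if not enable_chunked_prefill:
    assert max_num_batched_tokens >= max_model_len  # whole prompt must fit in one batch
else:
    assert max_num_batched_tokens >= block_size     # at least one cache block per step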
if self.guided_decoding_backend is not None:
assert self.guided_decoding_backend in ["xgrammar", "XGrammar", "auto", "off"], \
f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
assert self.guided_decoding_backend in [
"xgrammar",
"XGrammar",
"auto",
"off",
], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
if self.guided_decoding_backend != "off":
# TODO: mm support guided_decoding
@@ -878,8 +873,7 @@ class Config:
# TODO: speculative decoding support guided_decoding
# TODO: xpu support guided_decoding
assert not current_platform.is_xpu(
), "XPU currently do not support guided_decoding"
assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
try:
import xgrammar # noqa
@@ -897,22 +891,22 @@ class Config:
Args:
file (str): the path of file to save config
"""
llm_logger.info(
"=================== Configuration Information ===============")
llm_logger.info("=================== Configuration Information ===============")
for k, v in self.__dict__.items():
if k == "generation_config" and v is not None:
for gck, gcv in v.to_dict().items():
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
elif (k == "cache_config" or
k == "model_config" or
k == "scheduler_config" or
k == "parallel_config" or
k == "commit_config"):
elif (
k == "cache_config"
or k == "model_config"
or k == "scheduler_config"
or k == "parallel_config"
or k == "commit_config"
):
v.print()
else:
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
llm_logger.info("=============================================================")
if file is not None:
f = open(file, "a")
now_time = datetime.now()
@@ -929,15 +923,14 @@ class Config:
if self.splitwise_role != "mixed":
disaggregate_info["role"] = self.splitwise_role
disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split(
",")
current_protocol = self.cache_config.cache_transfer_protocol.split(",")
disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol:
if protocol == "ipc":
disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.engine_worker_queue_port,
"device_ids": self.local_device_ids
"device_ids": self.local_device_ids,
}
elif protocol == "rdma":
disaggregate_info["cache_info"][protocol] = {
@@ -957,13 +950,14 @@ class Config:
if hasattr(cls, key):
value = getattr(cls, key)
setattr(cls, value_name, value)
llm_logger.info(
f"Reset parameter {value_name} = {value} from configuration."
)
llm_logger.info(f"Reset parameter {value_name} = {value} from configuration.")
reset_value(self.cache_config, "block_size", "infer_model_block_size")
reset_value(self.model_config, "return_full_hidden_states",
"return_full_hidden_states")
reset_value(
self.model_config,
"return_full_hidden_states",
"return_full_hidden_states",
)
reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
def _check_master(self):

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import copy
@@ -40,18 +41,21 @@ from fastdeploy.engine.expert_service import start_expert_service
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (EngineCacheQueue, EngineWorkerQueue,
IPCSignal, ZmqClient)
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqClient,
)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import (TokenProcessor,
WarmUpTokenProcessor)
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
class LLMEngine(object):
class LLMEngine:
"""
Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -94,30 +98,28 @@ class LLMEngine(object):
self.running = True
self.scheduler = cfg.scheduler_config.scheduler()
self.input_processor = InputPreprocessor(cfg.tokenizer,
self.input_processor = InputPreprocessor(
cfg.tokenizer,
cfg.reasoning_parser,
cfg.limit_mm_per_prompt,
cfg.mm_processor_kwargs,
cfg.enable_mm)
cfg.enable_mm,
)
self.start_queue_service()
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg,
cfg.tensor_parallel_size,
cfg.splitwise_role)
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role)
os.environ['INFERENCE_MSG_QUEUE_ID'] = str(
self.cfg.engine_worker_queue_port)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
self.split_connector = SplitwiseConnector(cfg, self.scheduler,
self.engine_worker_queue,
self.resource_manager)
self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
self.token_processor = TokenProcessor(
cfg=self.cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector)
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
self.is_started = False
@@ -129,11 +131,13 @@ class LLMEngine(object):
else:
self.do_profile = 0
self.partial_chunked_tokens = [0] * (
self.cfg.max_num_partial_prefills + 1)
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \
// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
self.partial_chunked_tokens[idx] = (
(self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self.partial_chunked_tokens[idx] = max(1, self.partial_chunked_tokens[idx])
self._finalizer = weakref.finalize(self, self._exit_sub_services)
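A rough worked example of the per-slot chunk budget computed above; the values are made up.

# Assumed budget: 8192 batched tokens, 64-token blocks, up to 3 partial prefills.
max_num_batched_tokens, block_size, max_num_partial_prefills = 8192, 64, 3

partial_chunked_tokens = [0] * (max_num_partial_prefills + 1)
for idx in range(1, max_num_partial_prefills + 1):
    # Split the batch budget across idx concurrent partial prefills,
    # rounded down to a whole number of cache blocks, with a floor of 1.
    partial_chunked_tokens[idx] = (max_num_batched_tokens // idx) // block_size * block_size
    partial_chunked_tokens[idx] = max(1, partial_chunked_tokens[idx])

print(partial_chunked_tokens)  # [0, 8192, 4096, 2688]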
@@ -168,8 +172,8 @@ class LLMEngine(object):
time.sleep(3)
if self.do_profile == 0 and (
self.cfg.cache_config.enable_prefix_caching \
or self.cfg.splitwise_role != "mixed"):
self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
):
device_ids = self.cfg.device_ids.split(",")
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
@@ -177,16 +181,15 @@ class LLMEngine(object):
device_ids=device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix)
pid_suffix=self.ipc_signal_suffix,
)
self.worker_proc = self._start_worker_service()
console_logger.info("Waitting worker processes ready...")
time.sleep(5)
self.worker_init_status = dict()
if not self.check_worker_initialize_status():
console_logger.error(
"Failed to launch worker processes, check log/workerlog.* for more details."
)
console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
return False
# Start warmup if enabled
@@ -199,17 +202,16 @@ class LLMEngine(object):
self.token_processor.tasks_queue = self.engine_worker_queue
self.insert_task_to_worker_thread = threading.Thread(
target=self._insert_task_to_worker, daemon=True)
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
self.insert_task_to_worker_thread.start()
if self.api_server_pid is not None:
self.insert_task_to_scheduler_thread = threading.Thread(
target=self._insert_zmq_task_to_scheduler, daemon=True)
target=self._insert_zmq_task_to_scheduler, daemon=True
)
self.insert_task_to_scheduler_thread.start()
self.receive_output_thread = threading.Thread(
target=self._zmq_send_generated_tokens, daemon=True)
self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
self.receive_output_thread.start()
# Start TokenProcessor thread
@@ -223,8 +225,7 @@ class LLMEngine(object):
self.engine_worker_queue.available_prefill_instances.put(1)
self.split_mode_get_tasks()
if self.cfg.scheduler_config.name == "splitwise":
self.splitwise_receive_thread = threading.Thread(
target=self.split_connector.start_receiver, args=())
self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
@@ -240,20 +241,28 @@ class LLMEngine(object):
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.dp_processed = []
for i in range(1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
time.sleep(1)
self.dp_processed.append(
multiprocessing.Process(target=start_expert_service,
args=(self.cfg,
multiprocessing.Process(
target=start_expert_service,
args=(
self.cfg,
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
self.ipc_signal_suffix)))
llm_logger.info(f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" \
+ " data parallel id {}".format(i))
self.ipc_signal_suffix,
),
)
)
llm_logger.info(
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start()
console_logger.info(
"Worker processes are launched with {} seconds.".format(
time.time() - start_time))
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
return True
def _zmq_send_generated_tokens(self):
@@ -271,8 +280,7 @@ class LLMEngine(object):
self.zmq_server.send_multipart(request_id, contents)
except Exception as e:
llm_logger.error("Unexcepted error happend: {}, {}".format(
e, str(traceback.format_exc())))
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def _get_generated_result(self):
"""
@@ -296,8 +304,7 @@ class LLMEngine(object):
time.sleep(0.001)
continue
if self.exist_prefill_task_signal.value[0] > 0:
if self.cfg.splitwise_role == "mixed" or \
self.split_connector.has_splitwise_tasks():
if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
time.sleep(0.005)
continue
if self.engine_worker_queue.num_cache_infos() > 0:
@@ -309,17 +316,17 @@ class LLMEngine(object):
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch)
self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(
),
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.
enc_dec_block_num,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch)
batch=num_prefill_batch,
)
if len(tasks) == 0:
time.sleep(0.001)
@@ -328,16 +335,14 @@ class LLMEngine(object):
current_id = (current_id + 1) % 100003
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(
tasks, current_id)
self.split_connector.send_splitwise_tasks(tasks, current_id)
self.insert_tasks(tasks, current_id)
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format(
e, str(traceback.format_exc()))
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
def _insert_zmq_task_to_scheduler(self):
@@ -353,8 +358,7 @@ class LLMEngine(object):
else:
err, data = self.zmq_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error(
"Engine stops inserting zmq task into scheduler")
llm_logger.error("Engine stops inserting zmq task into scheduler")
break
request, insert_task = None, []
@@ -363,13 +367,11 @@ class LLMEngine(object):
request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
llm_logger.debug(f"Receive request: {request}")
err_msg = None
if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format(
request)
request, err_msg = self.guided_decoding_checker.schema_format(request)
if err_msg is not None:
llm_logger.error(err_msg)
@@ -394,17 +396,20 @@ class LLMEngine(object):
main_process_metrics.num_requests_waiting.inc(1)
continue
error_result = RequestOutput(request_id=request_id,
error_result = RequestOutput(
request_id=request_id,
finished=True,
error_code=500,
error_msg=failed)
error_msg=failed,
)
# Since the request is not in scheduler
# Send result by zmq directly
self.zmq_server.send_multipart(request_id, error_result)
except Exception as e:
llm_logger.error(
f"Error happend while receving new request from zmq, details={e}, "
f"traceback={traceback.format_exc()}")
f"traceback={traceback.format_exc()}"
)
def add_requests(self, task, sampling_params=None, **kwargs):
"""
@@ -428,23 +433,25 @@ class LLMEngine(object):
enable_thinking = None
if kwargs is not None:
enable_thinking = kwargs.get("enable_thinking", None)
request = self.data_processor.process_request(
request, self.cfg.max_model_len, enable_thinking=enable_thinking)
request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
request.prompt_token_ids_len = len(request.prompt_token_ids)
input_ids_len = request.prompt_token_ids_len
request.set(
"max_tokens",
min(self.cfg.max_model_len - input_ids_len,
request.get("max_tokens")))
min(
self.cfg.max_model_len - input_ids_len,
request.get("max_tokens"),
),
)
if request.get("reasoning_max_tokens") is None:
default_reasoning_max_tokens = max(
int(request.get("max_tokens") * 0.8), 1)
default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
request.set("reasoning_max_tokens", default_reasoning_max_tokens)
min_tokens = request.get("min_tokens")
if input_ids_len + min_tokens >= self.cfg.max_model_len:
error_msg = (
f"Input text is too long, length of prompt token({input_ids_len}) "
f"+ min_dec_len ({min_tokens}) >= max_model_len ")
f"+ min_dec_len ({min_tokens}) >= max_model_len "
)
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
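A small worked example of the token-budget arithmetic applied above; the numbers are made up.

max_model_len = 8192
input_ids_len = 1200          # prompt length after preprocessing
requested_max_tokens = 10000  # more than the context window allows
min_tokens = 1

max_tokens = min(max_model_len - input_ids_len, requested_max_tokens)  # clipped to 6992
reasoning_max_tokens = max(int(max_tokens * 0.8), 1)                   # defaults to 5593
assert input_ids_len + min_tokens < max_model_len                      # request is accepted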
@@ -456,16 +463,14 @@ class LLMEngine(object):
raise EngineError(error_msg, error_code=400)
if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format(
request)
request, err_msg = self.guided_decoding_checker.schema_format(request)
if err_msg is not None:
llm_logger.error(err_msg)
raise EngineError(err_msg, error_code=400)
request.preprocess_end_time = time.time()
self.scheduler.put_requests([request])
llm_logger.info(
f"Cache task with request_id ({request.get('request_id')})")
llm_logger.info(f"Cache task with request_id ({request.get('request_id')})")
llm_logger.debug(f"cache task: {request}")
def warmup(self):
@@ -486,25 +491,19 @@ class LLMEngine(object):
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
llm_logger.info(
f"Resource available, processing task {task.request_id}"
)
llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
llm_logger.debug(
f"Still waiting for resources {task.request_id}"
)
llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks(
)
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
@@ -515,7 +514,7 @@ class LLMEngine(object):
self.insert_tasks(tasks)
elif role == "decode":
if hasattr(tasks[0], 'finished'):
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
@@ -527,25 +526,19 @@ class LLMEngine(object):
else:
if len(self.waiting_requests):
llm_logger.info(
f"Waiting for resource for task {tasks[0].request_id}"
)
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
new_waiting.append(task)
if new_waiting:
self.waiting_requests.extend(
new_waiting)
llm_logger.info(
f"Added {len(new_waiting)} tasks to waiting queue"
)
self.waiting_requests.extend(new_waiting)
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
else:
time.sleep(0.001)
@@ -572,13 +565,10 @@ class LLMEngine(object):
if current_request_size[idx] <= 0:
chunk_request_num -= 1
if not self.cfg.cache_config.enable_chunked_prefill or len(
requests) == 0:
if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
return
current_request_size = [
request.prompt_token_ids_len for request in requests
]
current_request_size = [request.prompt_token_ids_len for request in requests]
requests_chunk = [[] for _ in range(len(requests))]
chunk_request_num = len(current_request_size)
while chunk_request_num >= 1:
@@ -588,25 +578,25 @@ class LLMEngine(object):
continue
chunk_size = min(
current_request_size[idx],
self.partial_chunked_tokens[chunk_request_num])
self.partial_chunked_tokens[chunk_request_num],
)
update_tokens(idx, chunk_size)
while remain_batched_tokens >= self.cfg.cache_config.block_size:
# When max_num_batched_tokens still has leftover budget, prefer allocating it to the shorter requests
waiting_requests = [
input_lens for input_lens in current_request_size
if input_lens > 0
]
waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0]
if len(waiting_requests) == 0:
break
available_tokens = remain_batched_tokens // self.cfg.cache_config.block_size * \
self.cfg.cache_config.block_size
available_tokens = (
remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
)
append_idx = current_request_size.index(min(waiting_requests))
chunk_size = min(
current_request_size[append_idx],
self.partial_chunked_tokens[chunk_request_num],
available_tokens)
available_tokens,
)
update_tokens(append_idx, chunk_size, update_chunk=True)
for idx in range(len(requests)):
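A simplified, self-contained sketch of the leftover-budget loop above; block_size, the remaining budget and the per-request sizes are made-up values, and the partial_chunked_tokens cap plus the bookkeeping done by update_tokens are deliberately elided:

# Toy version of the leftover-token assignment (illustrative numbers only).
block_size = 64
remain_batched_tokens = 300
current_request_size = [0, 130, 80]   # tokens still to prefill per request

while remain_batched_tokens >= block_size:
    waiting = [n for n in current_request_size if n > 0]
    if not waiting:
        break
    # Round the remaining budget down to a whole number of cache blocks.
    available_tokens = remain_batched_tokens // block_size * block_size
    # Prefer the shortest pending request.
    idx = current_request_size.index(min(waiting))
    chunk = min(current_request_size[idx], available_tokens)
    current_request_size[idx] -= chunk
    remain_batched_tokens -= chunk

print(current_request_size, remain_batched_tokens)   # [0, 0, 0] 90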
@@ -616,8 +606,7 @@ class LLMEngine(object):
"""
update each multimodal request's chunk size info
"""
if not self.cfg.cache_config.enable_chunked_prefill or len(
requests) == 0:
if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
return
for request in requests:
@@ -628,12 +617,9 @@ class LLMEngine(object):
inputs["grid_thw"] = np.array([], dtype="int64")
inputs["images"] = np.array([], dtype="uint8")
input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
image_type_ids = paddle.to_tensor(inputs["image_type_ids"],
dtype="int32")
image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32")
image_mask = input_ids == self.data_processor.image_patch_id
image_token_sum = paddle.full(shape=[len(input_ids) + 1],
fill_value=0,
dtype="int32")
image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32")
image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"))
grid_thw = []
for one in inputs["grid_thw"]:
@@ -644,45 +630,46 @@ class LLMEngine(object):
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
chunk_image_num, chunk_seq_len = get_mm_split_fuse(
input_ids, image_type_ids, image_token_sum, grid_thw,
self.data_processor.image_patch_id, len(grid_thw), 0,
len(input_ids), 0, self.partial_chunked_tokens[1], 2048)
input_ids,
image_type_ids,
image_token_sum,
grid_thw,
self.data_processor.image_patch_id,
len(grid_thw),
0,
len(input_ids),
0,
self.partial_chunked_tokens[1],
2048,
)
grid_thw = grid_thw.numpy().reshape([-1, 3])
num_chunks = len(chunk_image_num)
chunks_info = []
input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0
for idx in range(num_chunks):
chunk_input_ids = inputs["input_ids"][
input_ids_st:input_ids_st + chunk_seq_len[idx]]
chunk_token_type_ids = inputs["token_type_ids"][
input_ids_st:input_ids_st + chunk_seq_len[idx]]
actual_image_num = np.sum(grid_thw[grid_thw_st:grid_thw_st +
chunk_image_num[idx], 0])
chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0])
chunk_image_type_ids = inputs["image_type_ids"][
image_type_ids_st:image_type_ids_st + actual_image_num]
chunk_grid_thw = grid_thw[grid_thw_st:grid_thw_st +
chunk_image_num[idx]]
image_type_ids_st : image_type_ids_st + actual_image_num
]
chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]]
chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1))
chunk_images = inputs["images"][patch_st:patch_st +
chunk_patch_num]
chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num]
chunks_info.append({
"input_ids":
chunk_input_ids,
"token_type_ids":
chunk_token_type_ids,
"image_type_ids":
chunk_image_type_ids
if chunk_image_type_ids.shape[0] else None,
"grid_thw":
chunk_grid_thw if chunk_grid_thw.shape[0] else None,
"images":
chunk_images if chunk_images.shape[0] else None,
"position_ids":
None
})
chunks_info.append(
{
"input_ids": chunk_input_ids,
"token_type_ids": chunk_token_type_ids,
"image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None),
"grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None),
"images": (chunk_images if chunk_images.shape[0] else None),
"position_ids": None,
}
)
input_ids_st += chunk_seq_len[idx]
image_type_ids_st += actual_image_num
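A toy illustration of the running-offset slicing used above to cut the multimodal inputs into chunks; the chunk lengths here are invented rather than produced by get_mm_split_fuse:

import numpy as np

input_ids = np.arange(10)
chunk_seq_len = [4, 6]                # assumed chunk lengths

input_ids_st = 0
chunks = []
for idx in range(len(chunk_seq_len)):
    chunks.append(input_ids[input_ids_st : input_ids_st + chunk_seq_len[idx]])
    input_ids_st += chunk_seq_len[idx]

print([c.tolist() for c in chunks])   # [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9]]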
@@ -704,18 +691,14 @@ class LLMEngine(object):
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
if self.cfg.speculative_config.method in [
"mtp"
] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(
task.outputs.draft_token_ids)
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[
task.request_id]
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
@@ -723,8 +706,7 @@ class LLMEngine(object):
continue
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks(
(current_tasks, self.resource_manager.real_bsz))
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
@@ -737,9 +719,7 @@ class LLMEngine(object):
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(
"Inserting batch:{} exceeds the available batch:{}.".format(
len(tasks), available_batch))
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
@@ -763,8 +743,7 @@ class LLMEngine(object):
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[
i].prompt_token_ids_len
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
if not is_decode:
@@ -776,8 +755,7 @@ class LLMEngine(object):
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks(
(tasks, self.resource_manager.real_bsz))
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
if is_prefill and self.cfg.scheduler_config.name != "splitwise":
self.engine_worker_queue.available_prefill_instances.put(1)
return True
@@ -793,8 +771,7 @@ class LLMEngine(object):
"""
judge if all tasks are finished
"""
return np.sum(self.resource_manager.stop_flags) == len(
self.resource_manager.stop_flags)
return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
def _set_warmup_token_processor(self):
"""
@@ -824,8 +801,7 @@ class LLMEngine(object):
judge if all worker processes are ready
"""
if np.sum(self.worker_ready_signal.value
) == self.cfg.worker_num_per_node:
if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node:
return True
return False
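A toy version of the readiness check above, with a plain numpy array standing in for the shared IPCSignal buffer and a made-up worker count:

import numpy as np

worker_num_per_node = 4
worker_ready_signal = np.zeros(shape=[worker_num_per_node], dtype=np.int32)

worker_ready_signal[0] = 1            # each worker flips its own slot when it is ready
worker_ready_signal[1] = 1

print(np.sum(worker_ready_signal) == worker_num_per_node)   # False until all four report in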
@@ -835,30 +811,33 @@ class LLMEngine(object):
"""
# worker_ready_signal: one int32 slot per worker process on this node
worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_ready_signal = IPCSignal(name="worker_ready_signal",
self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
array=worker_ready_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
# exist_task_signal lets each worker process detect whether new tasks are waiting to be processed
exist_task_signal_data = np.zeros(
[self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
self.exist_task_signal = IPCSignal(name="exist_task_signal",
exist_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=exist_task_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
# exist_swapped_task_signal lets the engine detect whether a worker still holds swapped tasks
exist_swapped_task_signal_data = np.zeros(
[self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
exist_swapped_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=exist_swapped_task_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
# exist_prefill_task_signal lets each worker process detect whether a prefill step should run
exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
@@ -867,17 +846,18 @@ class LLMEngine(object):
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
# worker_live_signal lets the engine detect whether each worker process is alive; it records the timestamp of every step
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node],
dtype=np.int32)
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=worker_healthy_live_recorded_time_array,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
if self.do_profile:
get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
@@ -886,7 +866,8 @@ class LLMEngine(object):
array=get_profile_block_num,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
@@ -894,7 +875,8 @@ class LLMEngine(object):
array=model_weights_status,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True)
create=True,
)
def _exit_sub_services(self):
"""
@@ -903,8 +885,7 @@ class LLMEngine(object):
self.running = False
if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear(
)
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
@@ -943,7 +924,7 @@ class LLMEngine(object):
"TRAINER_INSTANCES_NUM": 1,
"TRAINER_INSTANCES": "0.0.0.0",
"ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
"LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(',')),
"LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(",")),
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring",
@@ -951,24 +932,22 @@ class LLMEngine(object):
"FLAGS_hardamard_moe_block_size": 128,
}
# environment variables needed by Dy2St
variables.update({
"SOT_LOG_LEVEL":
os.getenv("SOT_LOG_LEVEL", default="0"),
"SOT_UNSAFE_CACHE_FASTPATH":
os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK":
os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"FLAGS_specialize_device_in_dy2st":
os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc":
os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache":
os.getenv("FLAGS_pir_interpreter_record_stream_for_gc_cache",
default="1"),
"FLAGS_parameters_persistent_mode_in_dy2st":
os.getenv("FLAGS_parameters_persistent_mode_in_dy2st",
default="1"),
})
variables.update(
{
"SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
"SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
"FLAGS_pir_interpreter_record_stream_for_gc_cache",
default="1",
),
"FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
"FLAGS_parameters_persistent_mode_in_dy2st", default="1"
),
}
)
if self.cfg.splitwise_role != "mixed":
variables["FLAGS_use_pd_disaggregation"] = 1
@@ -994,8 +973,7 @@ class LLMEngine(object):
current_file_path = os.path.abspath(__file__)
current_dir_path = os.path.split(current_file_path)[0]
# TODO
uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT",
"0") == 1 else "-u"
uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == 1 else "-u"
pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch"
pd_cmd = pd_cmd + f" --log_dir {log_dir}"
@@ -1004,7 +982,7 @@ class LLMEngine(object):
ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, 'sp_model')
if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab)
)
@@ -1012,10 +990,10 @@ class LLMEngine(object):
f" --devices {self.cfg.device_ids} {py_script}"
f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
f" --model_name_or_path {str(self.cfg.model_name_or_path)}"
f" --model_name_or_path {self.cfg.model_name_or_path!s}"
f" --device_ids {self.cfg.device_ids}"
f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}"
f" --engine_worker_queue_port {self.cfg.engine_worker_queue_port!s}"
f" --pod_ip {self.cfg.master_ip}"
f" --total_block_num {self.cfg.cache_config.total_block_num}"
f" --block_size {self.cfg.cache_config.block_size}"
@@ -1036,16 +1014,13 @@ class LLMEngine(object):
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}")
f" --load_strategy {self.cfg.model_config.load_strategy}"
)
worker_append_flag = {
"enable_expert_parallel":
self.cfg.parallel_config.enable_expert_parallel,
"enable_prefix_caching":
self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill":
self.cfg.cache_config.enable_chunked_prefill,
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
"enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.disable_any_whitespace,
@@ -1059,11 +1034,11 @@ class LLMEngine(object):
if self.cfg.nnode > 1:
pd_cmd = pd_cmd + (
f" --master {self.cfg.dist_init_addr}"
f" --nnodes {str(self.cfg.nnode)}"
f" --rank {str(self.cfg.node_rank)}"
f" --nnodes {self.cfg.nnode!s}"
f" --rank {self.cfg.node_rank!s}"
)
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
llm_logger.info("Launch worker service command: {}".format(pd_cmd))
llm_logger.info(f"Launch worker service command: {pd_cmd}")
p = subprocess.Popen(
pd_cmd,
stdout=subprocess.PIPE,
@@ -1111,8 +1086,7 @@ class LLMEngine(object):
try:
req_id = self._format_and_add_data(prompts)
except Exception as e:
llm_logger.error(
f"Error happend while adding request, details={e}")
llm_logger.error(f"Error happend while adding request, details={e}")
raise EngineError(str(e), error_code=400)
# Fetch the result of the current request
@@ -1151,8 +1125,7 @@ class LLMEngine(object):
if num_gpu_blocks < 0:
num_gpu_blocks = self.get_profile_block_num_signal.value[i]
else:
num_gpu_blocks = min(
num_gpu_blocks, self.get_profile_block_num_signal.value[i])
num_gpu_blocks = min(num_gpu_blocks, self.get_profile_block_num_signal.value[i])
self.cfg.cache_config.reset(num_gpu_blocks)
self.resource_manager.reset_cache_config(self.cfg.cache_config)
@@ -1164,15 +1137,16 @@ class LLMEngine(object):
device_ids=device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix)
pid_suffix=self.ipc_signal_suffix,
)
def check_health(self, time_interval_threashold=30):
"""
Check the health of the model server by checking whether all workers are alive.
"""
if self.worker_healthy_live_signal.value[0]:
elapsed_time = time.time() - \
self.worker_healthy_live_signal.value[0]
elapsed_time = time.time() - self.worker_healthy_live_signal.value[0]
if elapsed_time > time_interval_threashold:
return False, "Worker Service Not Healthy"
@@ -1185,38 +1159,31 @@ class LLMEngine(object):
def detect_thread():
for line in self.worker_proc.stdout:
line = line.decode('utf-8', errors='ignore')
line = line.decode("utf-8", errors="ignore")
if self.worker_init_status.get("finished", False):
break
if match := re.search(
r'Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)',
line):
self.worker_init_status["weight_loadding"] = eval(
match.group(1)) * 1.0 / 100
elif (match := re.search(r'Start load layer (\d+)',
line)) or (match := re.search(
r'set state for layer (\d+)',
line)):
progress = eval(match.group(
1)) * 1.0 / self.cfg.model_config.num_layers
r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)",
line,
):
self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100
elif (match := re.search(r"Start load layer (\d+)", line)) or (
match := re.search(r"set state for layer (\d+)", line)
):
progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_layers
self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status[
"layer_loadding"] == self.cfg.model_config.num_layers - 1:
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_layers - 1:
self.worker_init_status["finished"] = True
self.checking_worker_status_thread = threading.Thread(
target=detect_thread, daemon=True)
self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
self.checking_worker_status_thread.start()
# display weight loading progress
with tqdm(total=100, desc="Loading Weights") as pbar:
progress = 0
while progress < 100:
progress = int(
self.worker_init_status.get("weight_loadding", 0) * 100)
if self.worker_init_status.get(
"layer_loadding",
0) > 0 or self._worker_processes_ready():
progress = int(self.worker_init_status.get("weight_loadding", 0) * 100)
if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready():
progress = 100
pbar.update(progress - pbar.n)
pbar.refresh()
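An isolated illustration of the progress-line parsing inside detect_thread, run on a made-up log line (int() is used here where the code above uses eval()):

import re

line = "Loading safetensors checkpoint shards: 42"
if match := re.search(r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)", line):
    weight_progress = int(match.group(1)) / 100
    print(weight_progress)                        # 0.42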
@@ -1228,8 +1195,7 @@ class LLMEngine(object):
with tqdm(total=100, desc="Loading Layers") as pbar:
progress = 0
while progress < 100:
progress = int(
self.worker_init_status.get("layer_loadding", 0) * 100)
progress = int(self.worker_init_status.get("layer_loadding", 0) * 100)
if self._worker_processes_ready():
progress = 100
pbar.update(progress - pbar.n)
@@ -1256,19 +1222,21 @@ class LLMEngine(object):
address=address,
is_server=True,
num_client=self.cfg.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.
data_parallel_size)
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed':
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
self.cache_task_queue = EngineCacheQueue(
address=(self.cfg.master_ip, self.cfg.cache_config.cache_queue_port),
authkey=b'cache_queue_service',
address=(
self.cfg.master_ip,
self.cfg.cache_config.cache_queue_port,
),
authkey=b"cache_queue_service",
is_server=True,
num_client=self.cfg.tensor_parallel_size,
client_id=-1,
local_data_parallel_size=self.cfg.parallel_config.
data_parallel_size)
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
@@ -1276,5 +1244,8 @@ class LLMEngine(object):
num_client=self.cfg.tensor_parallel_size,
client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id= min(self.cfg.worker_num_per_node * self.cfg.node_rank,
self.cfg.parallel_config.data_parallel_size - 1))
local_data_parallel_id=min(
self.cfg.worker_num_per_node * self.cfg.node_rank,
self.cfg.parallel_config.data_parallel_size - 1,
),
)

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
@@ -32,7 +33,7 @@ from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
class ExpertService(object):
class ExpertService:
"""
Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -51,17 +52,14 @@ class ExpertService(object):
self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[
start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(
",")[start_pos:end_pos]
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler()
self.scheduler.reset_nodeid(
f"{self.scheduler.infer.nodeid}_{str(local_data_parallel_id)}")
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
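A worked example of the per-rank device slicing above, with invented sizes (8 workers per node, tensor_parallel_size of 2):

tensor_parallel_size = 2
worker_num_per_node = 8
device_ids = "0,1,2,3,4,5,6,7".split(",")

local_data_parallel_id = 2
start_pos = (local_data_parallel_id * tensor_parallel_size) % worker_num_per_node        # 4
end_pos = ((local_data_parallel_id + 1) * tensor_parallel_size) % worker_num_per_node    # 6

print(device_ids[start_pos:end_pos])              # ['4', '5']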
@@ -73,33 +71,41 @@ class ExpertService(object):
num_client=cfg.tensor_parallel_size,
local_data_parallel_id=local_data_parallel_id,
)
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, \
cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id)
self.resource_manager = ResourceManager(
cfg.max_num_seqs,
cfg,
cfg.tensor_parallel_size,
cfg.splitwise_role,
local_data_parallel_id,
)
if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = int(
self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
self.cfg.cache_config.pd_comm_port[0] = int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
else:
self.cfg.cache_config.pd_comm_port = [
self.cfg.cache_config.pd_comm_port[local_data_parallel_id]
]
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.split_connector = SplitwiseConnector(self.cfg, self.scheduler,
self.split_connector = SplitwiseConnector(
self.cfg,
self.scheduler,
self.engine_worker_queue,
self.resource_manager)
self.resource_manager,
)
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector)
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
self.partial_chunked_tokens = [0] * (
self.cfg.max_num_partial_prefills + 1)
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \
// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
self.partial_chunked_tokens[idx] = (
(self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
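A worked example of the per-bucket chunk budget computed above; max_num_batched_tokens, block_size and max_num_partial_prefills are invented values:

max_num_batched_tokens = 8192
block_size = 64
max_num_partial_prefills = 3

partial_chunked_tokens = [0] * (max_num_partial_prefills + 1)
for idx in range(1, max_num_partial_prefills + 1):
    # Split the batched-token budget across idx concurrent partial prefills,
    # then round down to a whole number of cache blocks.
    partial_chunked_tokens[idx] = (max_num_batched_tokens // idx) // block_size * block_size

print(partial_chunked_tokens)   # [0, 8192, 4096, 2688]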
@@ -120,17 +126,15 @@ class ExpertService(object):
device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}"
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
)
self.insert_task_to_worker_thread = threading.Thread(
target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread.daemon = True
self.insert_task_to_worker_thread.start()
# Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.token_processor.run()
@@ -144,9 +148,7 @@ class ExpertService(object):
self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print()
console_logger.info(
"Worker processes are launched with {} seconds.".format(
time.time() - start_time))
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
return True
def _insert_task_to_worker(self):
@@ -169,17 +171,17 @@ class ExpertService(object):
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch)
self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(
),
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.
enc_dec_block_num,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch)
batch=num_prefill_batch,
)
if len(tasks) == 0:
time.sleep(0.001)
@@ -187,8 +189,7 @@ class ExpertService(object):
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(
tasks, current_id)
self.split_connector.send_splitwise_tasks(tasks, current_id)
current_id = (current_id + 1) % 100003
@@ -197,8 +198,7 @@ class ExpertService(object):
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format(
e, str(traceback.format_exc()))
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
def split_mode_get_tasks(self):
@@ -212,15 +212,13 @@ class ExpertService(object):
try:
if len(waiting_requests) > 0:
for task in waiting_requests:
if self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
waiting_requests.remove(task)
else:
break
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks(
)
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
@@ -231,7 +229,7 @@ class ExpertService(object):
self.insert_tasks(tasks)
elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}")
if hasattr(tasks[0], 'finished'):
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
@@ -246,7 +244,8 @@ class ExpertService(object):
else:
for task in tasks:
if not self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len):
task.prompt_token_ids_len
):
waiting_requests.append(task)
else:
self.insert_tasks([task])
@@ -274,8 +273,7 @@ class ExpertService(object):
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[
task.request_id]
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
@@ -285,8 +283,7 @@ class ExpertService(object):
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks(
(current_tasks, self.resource_manager.real_bsz))
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
@@ -299,9 +296,7 @@ class ExpertService(object):
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(
"Inserting batch:{} exceeds the available batch:{}.".format(
len(tasks), available_batch))
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
@@ -325,8 +320,7 @@ class ExpertService(object):
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[
i].prompt_token_ids_len
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks:
@@ -338,8 +332,7 @@ class ExpertService(object):
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks(
(tasks, self.resource_manager.real_bsz))
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
return True
def _exit_sub_services(self):
@@ -348,8 +341,7 @@ class ExpertService(object):
"""
if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear(
)
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
from dataclasses import dataclass
from typing import List
@@ -25,6 +26,7 @@ class KVCacheSpec:
"""
A base class for specifying the KV cache format of one layer.
"""
# number of tokens in a block
block_size: int
# the memory size used by each block in bytes.
@@ -37,10 +39,9 @@ class KVCacheSpec:
"""
# check list
assert all(
(spec.block_size == specs[0].block_size
and spec.block_memory_used == specs[0].block_memory_used)
for spec in specs[1:]), (
"All layers in the model must share the same block_size.")
(spec.block_size == specs[0].block_size and spec.block_memory_used == specs[0].block_memory_used)
for spec in specs[1:]
), "All layers in the model must share the same block_size."
return copy.deepcopy(specs[0])
@@ -48,6 +49,7 @@ class KVCacheSpec:
@dataclass
class AttentionSpec(KVCacheSpec):
""" """
num_kv_heads: int
head_size: int
dtype: str

View File

@@ -29,8 +29,8 @@ from fastdeploy.worker.output import LogprobsLists
@dataclass
class Request:
def __init__(self,
def __init__(
self,
request_id: str,
prompt: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[list[int]],
@@ -56,7 +56,8 @@ class Request:
structural_tag: Optional[Any] = None,
guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True,
trace_carrier: dict = dict()) -> None:
trace_carrier: dict = dict(),
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
@@ -98,7 +99,8 @@ class Request:
def from_dict(cls, d: dict):
data_processor_logger.debug(f"{d}")
sampling_params = SamplingParams.from_dict(d)
return cls(request_id=d["request_id"],
return cls(
request_id=d["request_id"],
prompt=d.get("prompt"),
prompt_token_ids=d.get("prompt_token_ids"),
prompt_token_ids_len=d.get("prompt_token_ids_len"),
@@ -123,7 +125,8 @@ class Request:
structural_tag=d.get("structural_tag", None),
guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", True),
trace_carrier=d.get("trace_carrier", {}))
trace_carrier=d.get("trace_carrier", {}),
)
def to_dict(self) -> dict:
"""convert Request into a serializable dict"""
@@ -146,11 +149,15 @@ class Request:
"disaggregate_info": self.disaggregate_info,
"draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking,
"trace_carrier": self.trace_carrier
"trace_carrier": self.trace_carrier,
}
add_params = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
"structural_tag", "guided_json_object"
"guided_json",
"guided_regex",
"guided_choice",
"guided_grammar",
"structural_tag",
"guided_json_object",
]
for param in add_params:
if getattr(self, param, None) is not None:
@@ -174,11 +181,13 @@ class Request:
setattr(self, key, value)
def __repr__(self) -> str:
return (f"Request(request_id={self.request_id}, "
return (
f"Request(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"draft_token_ids={self.draft_token_ids}, "
f"sampling_params={self.sampling_params})")
f"sampling_params={self.sampling_params})"
)
@dataclass(slots=True)
@@ -212,27 +221,28 @@ class CompletionOutput:
"top_logprobs": self.top_logprobs,
"draft_token_ids": self.draft_token_ids,
"text": self.text,
"reasoning_content": self.reasoning_content
"reasoning_content": self.reasoning_content,
}
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> 'CompletionOutput':
def from_dict(cls, req_dict: dict[str, Any]) -> CompletionOutput:
"""Create instance from dict arguments"""
return cls(
**{
field.name:
req_dict[field.name] if field.name in
req_dict else field.default
field.name: (req_dict[field.name] if field.name in req_dict else field.default)
for field in fields(cls)
})
}
)
def __repr__(self) -> str:
return (f"CompletionOutput(index={self.index}, "
return (
f"CompletionOutput(index={self.index}, "
f"send_idx={self.send_idx}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"draft_token_ids={self.draft_token_ids}, "
f"reasoning_content={self.reasoning_content!r}")
f"reasoning_content={self.reasoning_content!r}"
)
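A standalone sketch of the fields()-driven from_dict pattern used by CompletionOutput above, shown on a toy dataclass rather than the real class:

from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyOutput:
    index: int
    text: str = ""
    reasoning_content: Optional[str] = None

    @classmethod
    def from_dict(cls, req_dict: dict[str, Any]) -> ToyOutput:
        # Unknown keys are ignored; missing keys fall back to the field default.
        return cls(**{f.name: (req_dict[f.name] if f.name in req_dict else f.default) for f in fields(cls)})


print(ToyOutput.from_dict({"index": 0, "text": "hi", "extra": 1}))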
@dataclass(slots=True)
@@ -252,6 +262,7 @@ class RequestMetrics:
request_start_time: Time to accept the request
"""
arrival_time: float
inference_start_time: Optional[float] = None
first_token_time: Optional[float] = None
@@ -273,19 +284,18 @@ class RequestMetrics:
"preprocess_cost_time": self.preprocess_cost_time,
"model_forward_time": self.model_forward_time,
"model_execute_time": self.model_execute_time,
"request_start_time": self.request_start_time
"request_start_time": self.request_start_time,
}
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> 'RequestMetrics':
def from_dict(cls, req_dict: dict[str, Any]) -> RequestMetrics:
"""Create instance from dict arguments"""
return cls(
**{
field.name:
req_dict[field.name] if field.name in
req_dict else field.default
field.name: (req_dict[field.name] if field.name in req_dict else field.default)
for field in fields(cls)
})
}
)
class RequestOutput:
@@ -333,13 +343,12 @@ class RequestOutput:
self.error_code = error_code
self.error_msg = error_msg
if prompt_token_ids is None:
self.prompt_token_ids = []
elif isinstance(self.prompt_token_ids, np.ndarray):
self.prompt_token_ids = self.prompt_token_ids.tolist()
def add(self, next_output: "RequestOutput") -> None:
def add(self, next_output: RequestOutput) -> None:
"""Merge RequestOutput into this one"""
self.prompt = next_output.prompt
@@ -348,19 +357,19 @@ class RequestOutput:
self.outputs.index = next_output.outputs.index
self.outputs.token_ids.extend(next_output.outputs.token_ids)
if next_output.metrics.arrival_time is not None and self.metrics.inference_start_time is not None:
self.metrics.model_forward_time = next_output.metrics.arrival_time - \
self.metrics.inference_start_time
self.metrics.model_forward_time = next_output.metrics.arrival_time - self.metrics.inference_start_time
if next_output.metrics.arrival_time is not None and self.metrics.arrival_time is not None:
self.metrics.model_execute_time = next_output.metrics.arrival_time - \
self.metrics.arrival_time
self.metrics.model_execute_time = next_output.metrics.arrival_time - self.metrics.arrival_time
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs}, "
f"metrics={self.metrics}, "
f"num_cached_tokens={self.num_cached_tokens})")
f"num_cached_tokens={self.num_cached_tokens})"
)
@classmethod
def from_dict(cls, d: dict):
@@ -376,10 +385,8 @@ class RequestOutput:
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids,
"outputs":
None if self.outputs is None else self.outputs.to_dict(),
"metrics":
None if self.metrics is None else self.metrics.to_dict(),
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished,
"num_cached_tokens": self.num_cached_tokens,
"error_code": self.error_code,

View File

@@ -25,17 +25,19 @@ from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
class ResourceManager(object):
class ResourceManager:
"""
record and allocate resources for the engine
"""
def __init__(self,
def __init__(
self,
max_num_seqs,
config,
tensor_parallel_size,
splitwise_role,
local_data_parallel_id=0):
local_data_parallel_id=0,
):
"""
Args:
cfg (Config): config object containing parameters for the engine
@@ -51,9 +53,7 @@ class ResourceManager(object):
self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size,
splitwise_role,
local_data_parallel_id)
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
self.tasks_list = [None] * max_num_seqs
self.req_dict = dict()
# current batch status of the engine
@@ -77,8 +77,7 @@ class ResourceManager(object):
Returns:
int: block number
"""
block_num = (input_token_num + self.cfg.block_size - 1 +
self.cfg.dec_token_num) // self.cfg.block_size
block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
return block_num
def get_encoder_block_number(self, input_token_num):
@@ -91,8 +90,7 @@ class ResourceManager(object):
Returns:
int: encoder block number
"""
enc_block_num = (input_token_num + self.cfg.block_size -
1) // self.cfg.block_size
enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
return enc_block_num
def get_decoder_block_number(self):
@@ -102,8 +100,7 @@ class ResourceManager(object):
Returns:
int: decoder block number
"""
return (self.cfg.dec_token_num + self.cfg.block_size -
1) // self.cfg.block_size
return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
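A worked example of the block-count formulas above, with invented sizes:

block_size = 64
dec_token_num = 32
input_token_num = 1000

# prompt + reserved decode tokens, rounded up to whole blocks
block_num = (input_token_num + block_size - 1 + dec_token_num) // block_size     # 17
# encoder blocks for the prompt alone
enc_block_num = (input_token_num + block_size - 1) // block_size                 # 16
# blocks reserved for decoding
dec_block_num = (dec_token_num + block_size - 1) // block_size                   # 1

print(block_num, enc_block_num, dec_block_num)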
def total_block_number(self):
"""
@@ -132,13 +129,12 @@ class ResourceManager(object):
elif required_type == "decoder":
block_num = self.get_decoder_block_number()
else:
raise ValueError('unknown required type')
raise ValueError("unknown required type")
block_list = list()
current_block_num = self.available_block_num()
if block_num > current_block_num:
llm_logger.error("block_num:{0} > free_list len:{1}".format(
block_num, current_block_num))
llm_logger.error(f"block_num:{block_num} > free_list len:{current_block_num}")
return block_list
block_list = self.cache_manager.allocate_gpu_blocks(block_num)
llm_logger.debug(f"dispatch {len(block_list)} blocks.")
@@ -172,10 +168,8 @@ class ResourceManager(object):
ori_number = self.available_block_num()
self.cache_manager.recycle_gpu_blocks(block_tables)
cur_number = self.available_block_num()
main_process_metrics.gpu_cache_usage_perc.set(
self.get_gpu_cache_usage_perc())
llm_logger.info(
f"recycle {req_id} {cur_number - ori_number} blocks.")
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
llm_logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
def available_batch(self):
"""
@@ -238,8 +232,7 @@ class ResourceManager(object):
can_insert = False
while allocated_position + 1 <= self.max_num_seqs:
if sum(self.stop_flags[allocated_position:allocated_position +
1]) == 1:
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
can_insert = True
break
allocated_position += 1
@@ -249,72 +242,63 @@ class ResourceManager(object):
task = tasks[processing_task_index]
if task.get("seed") is None:
task.set("seed",
random.randint(0, 9223372036854775807))
task.set("seed", random.randint(0, 9223372036854775807))
task.idx = allocated_position
if self.enable_prefix_cache:
cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
task, self.cfg.block_size, self.cfg.dec_token_num)
task,
self.cfg.block_size,
self.cfg.dec_token_num,
)
if unique_block_ids is None:
llm_logger.warning(
"req_id: {0} not enough blocks available".
format(task["req_id"]))
llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
return
cached_len = self._record_request_cache_info(
task, common_block_ids, unique_block_ids, hit_info)
task.cache_prepare_time = time.time(
) - cache_prepare_time
task, common_block_ids, unique_block_ids, hit_info
)
task.cache_prepare_time = time.time() - cache_prepare_time
if task.disaggregate_info is not None:
if task.disaggregate_info['role'] == "prefill":
self.req_dict[
task.request_id] = allocated_position
task.disaggregate_info[
'block_tables'] = task.block_tables
if task.disaggregate_info["role"] == "prefill":
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.block_tables
self._delete_cached_data(task, cached_len)
elif task.disaggregate_info['role'] == "decode":
self.req_dict[
task.request_id] = allocated_position
task.disaggregate_info[
'block_tables'] = task.need_block_tables
elif task.disaggregate_info["role"] == "decode":
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.need_block_tables
else:
self._delete_cached_data(task, cached_len)
else:
block_tables = self._get_block_tables(
task.prompt_token_ids_len)
block_tables = self._get_block_tables(task.prompt_token_ids_len)
if not block_tables:
llm_logger.error(
"req_id: {0} block_tables is empty".format(
task.request_id))
llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
continue
else:
task.block_tables = block_tables
task.need_block_tables = task.block_tables
if task.disaggregate_info is not None:
task.disaggregate_info[
'block_tables'] = block_tables
if task.disaggregate_info['role'] == "prefill":
self.req_dict[
task.request_id] = allocated_position
elif task.disaggregate_info['role'] == "decode":
self.req_dict[
task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = block_tables
if task.disaggregate_info["role"] == "prefill":
self.req_dict[task.request_id] = allocated_position
elif task.disaggregate_info["role"] == "decode":
self.req_dict[task.request_id] = allocated_position
processed_tasks.append(task)
self.stop_flags[allocated_position] = False
task.inference_start_time = time.time()
task.inference_time_cost = -1.0
task.tokens_all_num = int(0)
task.tokens_all_num = 0
self.tasks_list[allocated_position] = task
llm_logger.info(
f"Allocate request: {task.request_id}, "
f"allocated_position:{allocated_position}, "
f"length of prompt token: {task.prompt_token_ids_len}")
f"length of prompt token: {task.prompt_token_ids_len}"
)
allocated_position += 1
processing_task_index += 1
@@ -325,11 +309,10 @@ class ResourceManager(object):
break
llm_logger.info(
f"Number of allocated requests: {len(tasks)}, number of "
f"running requests in worker: {self.real_bsz}")
f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
)
llm_logger.info(f"{self.info()}")
main_process_metrics.gpu_cache_usage_perc.set(
self.get_gpu_cache_usage_perc())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
return processed_tasks
@@ -345,19 +328,15 @@ class ResourceManager(object):
task.seq_lens_decoder = cached_len
task.prompt_token_ids_len = len(task.prompt_token_ids)
def _record_request_cache_info(self, task, common_block_ids,
unique_block_ids, hit_info):
def _record_request_cache_info(self, task, common_block_ids, unique_block_ids, hit_info):
"""
Record the cache information for a given task and its corresponding block IDs.
"""
cache_block_num = len(common_block_ids)
no_cache_block_num = math.ceil(len(task.prompt_token_ids) / self.cfg.block_size \
- cache_block_num)
no_cache_block_num = math.ceil(len(task.prompt_token_ids) / self.cfg.block_size - cache_block_num)
task.num_cached_tokens = cache_block_num * self.cfg.block_size
task.gpu_cache_token_num = hit_info[
"gpu_cache_blocks"] * self.cfg.block_size
task.cpu_cache_token_num = hit_info[
"cpu_cache_blocks"] * self.cfg.block_size
task.gpu_cache_token_num = hit_info["gpu_cache_blocks"] * self.cfg.block_size
task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
task.cache_info = (cache_block_num, no_cache_block_num)
cached_len = len(common_block_ids) * self.cfg.block_size
@@ -374,9 +353,11 @@ class ResourceManager(object):
Returns:
str: resource manager info
"""
info = f"ResourceManager info, " \
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
info = (
f"ResourceManager info, "
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, "
f"available_block_num: {self.available_block_num()}, available_batch: {self.available_batch()}"
)
return info
def get_gpu_cache_usage_perc(self):

View File

@@ -94,18 +94,18 @@ class SamplingParams:
bad_words: Optional[List[str]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(
**{
field.name:
req_dict[field.name] if field.name in
req_dict else field.default
field.name: (req_dict[field.name] if field.name in req_dict else field.default)
for field in fields(cls)
})
}
)
@classmethod
def from_optional(cls,
def from_optional(
cls,
n,
best_of,
presence_penalty,
@@ -121,16 +121,15 @@ class SamplingParams:
reasoning_max_tokens=None,
min_tokens=1,
logprobs=None,
bad_words=None) -> "SamplingParams":
bad_words=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(n=1 if n is None else n,
return cls(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=presence_penalty
if presence_penalty is not None else 0.0,
frequency_penalty=frequency_penalty
if frequency_penalty is not None else 0.0,
repetition_penalty=repetition_penalty
if repetition_penalty is not None else 1.0,
presence_penalty=(presence_penalty if presence_penalty is not None else 0.0),
frequency_penalty=(frequency_penalty if frequency_penalty is not None else 0.0),
repetition_penalty=(repetition_penalty if repetition_penalty is not None else 1.0),
temperature=temperature if temperature is not None else 1.0,
top_p=top_p,
top_k=top_k if top_k is not None else 0,
@@ -141,7 +140,8 @@ class SamplingParams:
reasoning_max_tokens=reasoning_max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words)
bad_words=bad_words,
)
def __post_init__(self):
if self.seed is None:
@@ -152,60 +152,44 @@ class SamplingParams:
def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(
f"n must be an int, but is of type {type(self.n)}")
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if self.presence_penalty is not None and (
not -2.0 <= self.presence_penalty <= 2.0):
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
if self.frequency_penalty is not None and (
not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if self.presence_penalty is not None and (not -2.0 <= self.presence_penalty <= 2.0):
raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.")
if self.frequency_penalty is not None and (not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got " f"{self.frequency_penalty}.")
if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
raise ValueError("repetition_penalty must be greater than zero, got " f"{self.repetition_penalty}.")
if self.temperature is not None and self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
raise ValueError(f"temperature must be non-negative, got {self.temperature}.")
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
# quietly accept -1 as disabled, but prefer 0
if self.top_k < -1:
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
f"got {self.top_k}.")
raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
if not isinstance(self.top_k, int):
raise TypeError(
f"top_k must be an integer, got {type(self.top_k).__name__}")
raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
raise ValueError(
f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
raise ValueError(f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
raise ValueError(f"min_tokens must be greater than or equal to 0, " f"got {self.min_tokens}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs < 0:
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20:
raise ValueError(
"Invalid value for 'top_logprobs': must be less than or equal to 20.")
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got "
f"{self.seed}.")
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
"""
@@ -218,6 +202,7 @@ class SamplingParams:
@dataclass
class BeamSearchParams:
"""Beam search parameters for text generation."""
beam_width: int
max_tokens: int
ignore_eos: bool = False

View File

@@ -14,19 +14,25 @@
# limitations under the License.
"""
import uvicorn
import json
import uvicorn
from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse
from fastdeploy.utils import FlexibleArgumentParser, api_server_logger, is_port_available
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.utils import (
FlexibleArgumentParser,
api_server_logger,
is_port_available,
)
app = FastAPI()
llm_engine = None
def init_app(args):
"""
init LLMEngine
@@ -39,7 +45,7 @@ def init_app(args):
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
return False
api_server_logger.info(f"FastDeploy LLM engine initialized!")
api_server_logger.info("FastDeploy LLM engine initialized!")
return True
@@ -48,6 +54,7 @@ async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: dict):
"""
@@ -64,7 +71,7 @@ async def generate(request: dict):
output = result
except Exception as e:
# Log the full exception traceback
api_server_logger.error(f"Error during generation: {str(e)}", exc_info=True)
api_server_logger.error(f"Error during generation: {e!s}", exc_info=True)
# Return a structured error message and terminate the stream
output = {"error": str(e), "error_type": e.__class__.__name__}
return output
@@ -76,12 +83,14 @@ async def generate(request: dict):
yield f"data: {json.dumps(result)}\n\n"
except Exception as e:
# Log the full exception traceback
api_server_logger.error(f"Error during generation: {str(e)}", exc_info=True)
api_server_logger.error(f"Error during generation: {e!s}", exc_info=True)
# Return a structured error message and terminate the stream
error_msg = {"error": str(e), "error_type": e.__class__.__name__}
yield f"data: {json.dumps(error_msg)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
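For context, a hedged client-side sketch for consuming the streaming branch above; the host, port and request fields are assumptions about how the /generate endpoint might be called, not something this file defines:

import json
import requests

with requests.post(
    "http://127.0.0.1:8000/generate",
    json={"prompt": "hello", "stream": True},   # assumed payload shape
    stream=True,
) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        chunk = json.loads(raw[len("data: "):])
        print(chunk)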
def launch_api_server(args) -> None:
"""
Start the HTTP service
@@ -97,11 +106,13 @@ def launch_api_server(args) -> None:
return
try:
uvicorn.run(app=app,
uvicorn.run(
app=app,
host=args.host,
port=args.port,
workers=args.workers,
log_level="info") # set log level to error to avoid log
log_level="info",
) # set log level to error to avoid log
except Exception as e:
api_server_logger.error(f"launch sync http server error, {e}")

View File

@@ -14,35 +14,45 @@
# limitations under the License.
"""
from typing import Literal, Union, List
from typing_extensions import Required, TypedDict, TypeAlias
from openai.types.chat import ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
from urllib.parse import urlparse
import requests
from copy import deepcopy
from typing import List, Literal, Union
from urllib.parse import urlparse
import requests
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from typing_extensions import Required, TypeAlias, TypedDict
from fastdeploy.input.multimodal.video import VideoMediaIO
from fastdeploy.input.multimodal.image import ImageMediaIO
from fastdeploy.input.multimodal.video import VideoMediaIO
class VideoURL(TypedDict, total=False):
"""Video URL object"""
url: Required[str]
"""Either a URL of the video or the base64 encoded video data"""
class CustomChatCompletionContentPartVideoParam(TypedDict, total=False):
"""Custom Video URL object"""
video_url: Required[VideoURL]
type: Required[Literal["video_url"]]
"""The type of the content type."""
CustomChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, CustomChatCompletionContentPartVideoParam
OpenAIChatCompletionContentPartParam,
CustomChatCompletionContentPartVideoParam,
]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Custom User chat message parameter."""
@@ -58,11 +68,13 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
Provides the model information to differentiate between participants of the same role.
"""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam]
class MultiModalPartParser(object):
class MultiModalPartParser:
"""Multi Modal Part parser"""
def __init__(self):
self.image_io = ImageMediaIO()
self.video_io = VideoMediaIO()
@@ -92,6 +104,7 @@ class MultiModalPartParser(object):
localpath = parsed.path
return media_io.load_file(localpath)
def parse_content_part(mm_parser, part):
"""only support openai compatible format for now"""
@@ -120,6 +133,7 @@ def parse_content_part(mm_parser, part):
raise ValueError(f"Unknown content part type: {part_type}")
# TODO async
# def parse_chat_messages(messages: List[ChatCompletionMessageParam]):
def parse_chat_messages(messages):
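For reference, a minimal sketch of the message shape the TypedDicts above describe (the URL is a placeholder); this is plain data and needs nothing beyond the OpenAI-compatible format:

# A user message carrying the custom video part defined above.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this clip."},
        {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
    ],
}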

View File

@@ -14,17 +14,15 @@
# limitations under the License.
"""
import zmq
import time
from random import randint
import uuid
import numpy as np
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import ZmqClient, IPCSignal
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger, EngineError
from fastdeploy.utils import EngineError, api_server_logger
class EngineClient:
@@ -32,23 +30,36 @@ class EngineClient:
EngineClient is a class that handles the communication between the client and the server.
"""
def __init__(self, tokenizer, max_model_len, tensor_parallel_size, pid, limit_mm_per_prompt, mm_processor_kwargs,
enable_mm=False, reasoning_parser=None):
input_processor = InputPreprocessor(tokenizer,
def __init__(
self,
tokenizer,
max_model_len,
tensor_parallel_size,
pid,
limit_mm_per_prompt,
mm_processor_kwargs,
enable_mm=False,
reasoning_parser=None,
):
input_processor = InputPreprocessor(
tokenizer,
reasoning_parser,
limit_mm_per_prompt,
mm_processor_kwargs,
enable_mm)
enable_mm,
)
self.enable_mm = enable_mm
self.reasoning_parser = reasoning_parser
self.data_processor = input_processor.create_processor()
self.max_model_len = max_model_len
self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(name="worker_healthy_live_signal",
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=self.worker_healthy_live_recorded_time_array,
dtype=np.int32,
suffix=pid,
create=False)
create=False,
)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
@@ -56,7 +67,8 @@ class EngineClient:
array=model_weights_status,
dtype=np.int32,
suffix=pid,
create=False)
create=False,
)
def create_zmq_client(self, model, mode):
"""
@@ -75,7 +87,6 @@ class EngineClient:
if "request_id" not in prompts:
request_id = str(uuid.uuid4())
prompts["request_id"] = request_id
query_list = []
if "max_tokens" not in prompts:
prompts["max_tokens"] = self.max_model_len - 1
@@ -105,8 +116,8 @@ class EngineClient:
if task.get("reasoning_max_tokens", None) is None:
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
min_tokens = task.get("min_tokens", 1)
if 'messages' in task:
del task['messages']
if "messages" in task:
del task["messages"]
api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}")
work_process_metrics.request_params_max_tokens.observe(task["max_tokens"])
work_process_metrics.prompt_tokens_total.inc(input_ids_len)
@@ -133,8 +144,7 @@ class EngineClient:
task["preprocess_end_time"] = time.time()
preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
api_server_logger.info(
f"Cache request with request_id ({task.get('request_id')}), "
f"cost {time.time() - preprocess_cost_time}"
f"Cache request with request_id ({task.get('request_id')}), " f"cost {time.time() - preprocess_cost_time}"
)
self.vaild_parameters(task)
@@ -153,7 +163,6 @@ class EngineClient:
Validate stream options
"""
if data.get("n"):
if data["n"] != 1:
raise ValueError("n only support 1.")
@@ -168,9 +177,7 @@ class EngineClient:
if data.get("top_p"):
if data["top_p"] > 1 or data["top_p"] < 0:
raise ValueError(
"top_p value can only be defined [0, 1].")
raise ValueError("top_p value can only be defined [0, 1].")
if data.get("frequency_penalty"):
if not -2.0 <= data["frequency_penalty"] <= 2.0:
@@ -178,24 +185,18 @@ class EngineClient:
if data.get("temperature"):
if data["temperature"] < 0:
raise ValueError(f"temperature must be non-negative")
raise ValueError("temperature must be non-negative")
if data.get("presence_penalty"):
if not -2.0 <= data["presence_penalty"] <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2]")
if data.get("seed"):
if not 0 <= data["seed"] <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580]")
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when `stream=True`.")
raise ValueError("Stream options can only be defined when `stream=True`.")
def check_health(self, time_interval_threashold=30):
"""
@@ -209,7 +210,6 @@ class EngineClient:
return True, ""
def is_workers_alive(self):
"""
Check the health of the model server by checking whether all workers are alive.
@@ -220,8 +220,6 @@ class EngineClient:
else:
return False, "No model weight enabled"
def update_model_weight(self, timeout=300):
"""
Update the model weight by sending a signal to the server.
@@ -244,8 +242,6 @@ class EngineClient:
time.sleep(1)
return True, ""
def clear_load_weight(self, timeout=300):
"""
Clear the load weight status.
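A tiny worked sketch of the max_tokens defaulting applied earlier in this file's diff, as pure arithmetic (no engine required): when a request omits max_tokens it falls back to max_model_len - 1, and reasoning_max_tokens then defaults to 80% of that, floored to at least 1.

max_model_len = 8192                       # example value only
max_tokens = max_model_len - 1             # default when the request omits max_tokens
reasoning_max_tokens = max(int(max_tokens * 0.8), 1)
assert (max_tokens, reasoning_max_tokens) == (8191, 6552)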

View File

@@ -28,6 +28,7 @@ from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams
# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger, retrive_model_from_server
@@ -78,16 +79,14 @@ class LLM:
# Create the Engine
self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)
self.default_sampling_params = SamplingParams(
max_tokens=self.llm_engine.cfg.max_model_len)
self.default_sampling_params = SamplingParams(max_tokens=self.llm_engine.cfg.max_model_len)
self.llm_engine.start()
self.mutex = threading.Lock()
self.req_output = dict()
self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread(
target=self._receive_output, daemon=True)
self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
self._receive_output_thread.start()
def _check_master(self):
@@ -111,15 +110,19 @@ class LLM:
continue
self.req_output[request_id].add(result)
except Exception as e:
llm_logger.error("Unexcepted error happend: {}, {}".format(
e, str(traceback.format_exc())))
llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")
def generate(
self,
prompts: Union[str, list[str], list[int], list[list[int]],
dict[str, Any], list[dict[str, Any]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompts: Union[
str,
list[str],
list[int],
list[list[int]],
dict[str, Any],
list[dict[str, Any]],
],
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
):
"""
@@ -161,11 +164,9 @@ class LLM:
# sampling_params = None
if sampling_params_len != 1 and len(prompts) != sampling_params_len:
raise ValueError(
"prompts and sampling_params must be the same length.")
raise ValueError("prompts and sampling_params must be the same length.")
req_ids = self._add_request(prompts=prompts,
sampling_params=sampling_params)
req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -176,8 +177,7 @@ class LLM:
def chat(
self,
messages: Union[list[Any], list[list[Any]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
chat_template_kwargs: Optional[dict[str, Any]] = None,
):
@@ -211,15 +211,16 @@ class LLM:
messages = [messages]
if sampling_params_len != 1 and len(messages) != sampling_params_len:
raise ValueError(
"messages and sampling_params must be the same length.")
raise ValueError("messages and sampling_params must be the same length.")
messages_len = len(messages)
for i in range(messages_len):
messages[i] = {"messages": messages[i]}
req_ids = self._add_request(prompts=messages,
req_ids = self._add_request(
prompts=messages,
sampling_params=sampling_params,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
)
# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -253,8 +254,7 @@ class LLM:
"prompt": prompts[i],
"request_id": request_id,
}
elif isinstance(prompts[i], list) and isinstance(
prompts[i][0], int):
elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
tasks = {
"prompt_token_ids": prompts[i],
"request_id": request_id,
@@ -273,11 +273,8 @@ class LLM:
current_sampling_params = sampling_params
enable_thinking = None
if chat_template_kwargs is not None:
enable_thinking = chat_template_kwargs.get(
"enable_thinking", None)
self.llm_engine.add_requests(tasks,
current_sampling_params,
enable_thinking=enable_thinking)
enable_thinking = chat_template_kwargs.get("enable_thinking", None)
self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
return req_ids
def _run_engine(self, req_ids: list[str], use_tqdm: bool):
@@ -303,8 +300,7 @@ class LLM:
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
)
output = [None] * num_requests
@@ -322,13 +318,11 @@ class LLM:
continue
result = self.req_output.pop(req_id)
result = self.llm_engine.data_processor.process_response(
result)
result = self.llm_engine.data_processor.process_response(result)
output[pos] = result
finished.append(i)
llm_logger.debug(
"Request id: {} has been completed.".format(req_id))
llm_logger.debug(f"Request id: {req_id} has been completed.")
if use_tqdm:
pbar.update(1)
@@ -346,24 +340,27 @@ if __name__ == "__main__":
# llm = LLM(model="llama_model")
# output = llm.generate(prompts="who are you", use_tqdm=True)
# print(output)
llm = LLM(model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
tensor_parallel_size=2)
llm = LLM(
model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
tensor_parallel_size=2,
)
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
output = llm.generate(prompts="who are you",
output = llm.generate(prompts="who are you", use_tqdm=True, sampling_params=sampling_params)
print(output)
output = llm.generate(
prompts=["who are you", "what can you do"],
sampling_params=SamplingParams(temperature=1, max_tokens=50),
use_tqdm=True,
sampling_params=sampling_params)
)
print(output)
output = llm.generate(prompts=["who are you", "what can you do"],
sampling_params=SamplingParams(temperature=1,
max_tokens=50),
use_tqdm=True)
print(output)
output = llm.generate(prompts=["who are you", "I miss you"],
output = llm.generate(
prompts=["who are you", "I miss you"],
sampling_params=[
SamplingParams(temperature=1, max_tokens=50),
SamplingParams(temperature=1, max_tokens=20)
SamplingParams(temperature=1, max_tokens=20),
],
use_tqdm=True)
use_tqdm=True,
)
print(output)
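Following the __main__ demo above, a hedged sketch of the chat() path; the model path is a placeholder, and LLM is the class defined in this file, so only SamplingParams needs importing here.

from fastdeploy.engine.sampling_params import SamplingParams  # as imported at the top of this file

llm = LLM(model="/path/to/your/model", tensor_parallel_size=1)  # placeholder model path
outputs = llm.chat(
    messages=[{"role": "user", "content": "who are you"}],
    sampling_params=SamplingParams(temperature=0.7, max_tokens=64),
    chat_template_kwargs={"enable_thinking": False},  # forwarded to add_requests() as shown above
)
print(outputs)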

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import threading
import time
@@ -24,46 +25,41 @@ import zmq
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST
from fastdeploy.metrics.trace_util import inject_to_metadata,instrument
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
ControlSchedulerRequest,
ErrorResponse,
ControlSchedulerRequest)
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import \
OpenAIServingCompletion
from fastdeploy.metrics.metrics import (EXCLUDE_LABELS,
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.metrics.metrics import (
EXCLUDE_LABELS,
cleanup_prometheus_files,
get_filtered_metrics,
main_process_metrics)
from fastdeploy.utils import (FlexibleArgumentParser, api_server_logger,
console_logger, is_port_available,
retrive_model_from_server)
main_process_metrics,
)
from fastdeploy.metrics.trace_util import inject_to_metadata, instrument
from fastdeploy.utils import (
FlexibleArgumentParser,
api_server_logger,
console_logger,
is_port_available,
retrive_model_from_server,
)
parser = FlexibleArgumentParser()
parser.add_argument("--port",
default=8000,
type=int,
help="port to the http server")
parser.add_argument("--host",
default="0.0.0.0",
type=str,
help="host to the http server")
parser.add_argument("--port", default=8000, type=int, help="port to the http server")
parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
parser.add_argument("--workers", default=1, type=int, help="number of workers")
parser.add_argument("--metrics-port",
default=8001,
type=int,
help="port for metrics server")
parser.add_argument("--controller-port",
default=-1,
type=int,
help="port for controller server")
parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model)
@@ -79,26 +75,18 @@ def load_engine():
if llm_engine is not None:
return llm_engine
api_server_logger.info(
f"FastDeploy LLM API server starting... {os.getpid()}")
api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
if not engine.start(api_server_pid=os.getpid()):
api_server_logger.error(
"Failed to initialize FastDeploy LLM engine, service exit now!")
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
return None
api_server_logger.info("FastDeploy LLM engine initialized!\n")
console_logger.info(
f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics"
)
console_logger.info(
f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions"
)
console_logger.info(
f"Launching completion service at http://{args.host}:{args.port}/v1/completions"
)
console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics")
console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions")
console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions")
llm_engine = engine
return engine
@@ -111,16 +99,21 @@ async def lifespan(app: FastAPI):
if args.tokenizer is None:
args.tokenizer = args.model
if current_process().name != 'MainProcess':
if current_process().name != "MainProcess":
pid = os.getppid()
else:
pid = os.getpid()
api_server_logger.info(f"{pid}")
engine_client = EngineClient(args.tokenizer, args.max_model_len,
args.tensor_parallel_size, pid,
engine_client = EngineClient(
args.tokenizer,
args.max_model_len,
args.tensor_parallel_size,
pid,
args.limit_mm_per_prompt,
args.mm_processor_kwargs, args.enable_mm,
args.reasoning_parser)
args.mm_processor_kwargs,
args.enable_mm,
args.reasoning_parser,
)
app.state.dynamic_load_weight = args.dynamic_load_weight
chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip)
completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip)
@@ -134,6 +127,7 @@ async def lifespan(app: FastAPI):
try:
engine_client.zmq_client.close()
from prometheus_client import multiprocess
multiprocess.mark_process_dead(os.getpid())
api_server_logger.info(f"Closing metrics client pid: {pid}")
except Exception as e:
@@ -187,11 +181,7 @@ async def list_all_routes():
if route.path.startswith("/v1"):
methods = sorted(route.methods)
tags = getattr(route, "tags", []) or []
routes_info.append({
"path": route.path,
"methods": methods,
"tags": tags
})
routes_info.append({"path": route.path, "methods": methods, "tags": tags})
return {"routes": routes_info}
@@ -209,15 +199,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive()
if not status:
return JSONResponse(
content={"error": "Worker Service Not Healthy"},
status_code=304)
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
inject_to_metadata(request)
generator = await app.state.chat_handler.create_chat_completion(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
return JSONResponse(content=generator.model_dump(), status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
@@ -233,14 +220,11 @@ async def create_completion(request: CompletionRequest):
if app.state.dynamic_load_weight:
status, msg = app.state.engine_client.is_workers_alive()
if not status:
return JSONResponse(
content={"error": "Worker Service Not Healthy"},
status_code=304)
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
generator = await app.state.completion_handler.create_completion(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
return JSONResponse(content=generator.model_dump(), status_code=generator.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
@@ -258,8 +242,7 @@ def update_model_weight(request: Request) -> Response:
return Response(content=msg, status_code=404)
return Response(status_code=200)
else:
return Response(content="Dynamic Load Weight Disabled.",
status_code=404)
return Response(content="Dynamic Load Weight Disabled.", status_code=404)
@app.get("/clear_load_weight")
@@ -273,8 +256,7 @@ def clear_load_weight(request: Request) -> Response:
return Response(content=msg, status_code=404)
return Response(status_code=200)
else:
return Response(content="Dynamic Load Weight Disabled.",
status_code=404)
return Response(content="Dynamic Load Weight Disabled.", status_code=404)
def launch_api_server() -> None:
@@ -284,16 +266,17 @@ def launch_api_server() -> None:
if not is_port_available(args.host, args.port):
raise Exception(f"The parameter `port`:{args.port} is already in use.")
api_server_logger.info(
f"launch Fastdeploy api server... port: {args.port}")
api_server_logger.info(f"launch Fastdeploy api server... port: {args.port}")
api_server_logger.info(f"args: {args.__dict__}")
try:
uvicorn.run(app="fastdeploy.entrypoints.openai.api_server:app",
uvicorn.run(
app="fastdeploy.entrypoints.openai.api_server:app",
host=args.host,
port=args.port,
workers=args.workers,
log_level="info") # set log level to error to avoid log
log_level="info",
) # set log level to error to avoid log
except Exception as e:
api_server_logger.error(f"launch sync http server error, {e}")
@@ -308,8 +291,8 @@ async def metrics():
"""
metrics_text = get_filtered_metrics(
EXCLUDE_LABELS,
extra_register_func=lambda reg: main_process_metrics.register_all(
reg, workers=args.workers))
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=args.workers),
)
return Response(metrics_text, media_type=CONTENT_TYPE_LATEST)
@@ -318,23 +301,17 @@ def run_metrics_server():
run metrics server
"""
uvicorn.run(metrics_app,
host="0.0.0.0",
port=args.metrics_port,
log_level="error")
uvicorn.run(metrics_app, host="0.0.0.0", port=args.metrics_port, log_level="error")
def launch_metrics_server():
"""Metrics server running the sub thread"""
if not is_port_available(args.host, args.metrics_port):
raise Exception(
f"The parameter `metrics_port`:{args.metrics_port} is already in use."
)
raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.")
prom_dir = cleanup_prometheus_files(True)
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
metrics_server_thread = threading.Thread(target=run_metrics_server,
daemon=True)
metrics_server_thread = threading.Thread(target=run_metrics_server, daemon=True)
metrics_server_thread.start()
time.sleep(1)
@@ -375,7 +352,8 @@ def control_scheduler(request: ControlSchedulerRequest):
if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
llm_engine.scheduler.update_config(
load_shards_num=request.load_shards_num,
reallocate=request.reallocate_shard)
reallocate=request.reallocate_shard,
)
else:
content.message = "This scheduler doesn't support the `update_config()` method."
content.code = 400
@@ -388,10 +366,12 @@ def run_controller_server():
"""
run controller server
"""
uvicorn.run(controller_app,
uvicorn.run(
controller_app,
host="0.0.0.0",
port=args.controller_port,
log_level="error")
log_level="error",
)
def launch_controller_server():
@@ -400,12 +380,9 @@ def launch_controller_server():
return
if not is_port_available(args.host, args.controller_port):
raise Exception(
f"The parameter `controller_port`:{args.controller_port} is already in use."
)
raise Exception(f"The parameter `controller_port`:{args.controller_port} is already in use.")
controller_server_thread = threading.Thread(target=run_controller_server,
daemon=True)
controller_server_thread = threading.Thread(target=run_controller_server, daemon=True)
controller_server_thread.start()
time.sleep(1)

View File

@@ -30,6 +30,7 @@ class ErrorResponse(BaseModel):
"""
Error response from OpenAI API.
"""
object: str = "error"
message: str
code: int
@@ -39,6 +40,7 @@ class PromptTokenUsageInfo(BaseModel):
"""
Prompt-related token usage info.
"""
cached_tokens: Optional[int] = None
@@ -46,6 +48,7 @@ class UsageInfo(BaseModel):
"""
Usage info for a single request.
"""
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
@@ -56,6 +59,7 @@ class FunctionCall(BaseModel):
"""
Function call.
"""
name: str
arguments: str
@@ -64,6 +68,7 @@ class ToolCall(BaseModel):
"""
Tool call.
"""
id: str = None
type: Literal["function"] = "function"
function: FunctionCall
@@ -74,6 +79,7 @@ class DeltaFunctionCall(BaseModel):
"""
Delta function call.
"""
name: Optional[str] = None
arguments: Optional[str] = None
@@ -83,6 +89,7 @@ class DeltaToolCall(BaseModel):
"""
Delta tool call.
"""
id: Optional[str] = None
type: Optional[Literal["function"]] = None
index: int
@@ -93,6 +100,7 @@ class FunctionDefinition(BaseModel):
"""
Function definition.
"""
name: str
description: Optional[str] = None
parameters: Optional[dict[str, Any]] = None
@@ -102,6 +110,7 @@ class ChatCompletionToolsParam(BaseModel):
"""
Chat completion tools parameter.
"""
type: Literal["function"] = "function"
function: FunctionDefinition
@@ -110,6 +119,7 @@ class ChatMessage(BaseModel):
"""
Chat message.
"""
role: str
content: str
reasoning_content: Optional[str] = None
@@ -120,6 +130,7 @@ class ChatCompletionResponseChoice(BaseModel):
"""
Chat completion response choice.
"""
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
@@ -130,6 +141,7 @@ class ChatCompletionResponse(BaseModel):
"""
Chat completion response.
"""
id: str
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -137,26 +149,32 @@ class ChatCompletionResponse(BaseModel):
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class LogProbEntry(BaseModel):
"""
Log probability entry.
"""
token: str
logprob: float
bytes: Optional[List[int]] = None
top_logprobs: Optional[List["LogProbEntry"]] = None
top_logprobs: Optional[List[LogProbEntry]] = None
class LogProbs(BaseModel):
"""
LogProbs.
"""
content: Optional[List[LogProbEntry]] = None
refusal: Optional[Union[str, None]] = None
class DeltaMessage(BaseModel):
"""
Delta message for chat completion stream response.
"""
role: Optional[str] = None
content: Optional[str] = None
token_ids: Optional[List[int]] = None
@@ -168,6 +186,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
"""
Chat completion response choice for stream response.
"""
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
@@ -179,6 +198,7 @@ class ChatCompletionStreamResponse(BaseModel):
"""
Chat completion response for stream response.
"""
id: str
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -191,6 +211,7 @@ class CompletionResponseChoice(BaseModel):
"""
Completion response choice.
"""
index: int
text: str
token_ids: Optional[List[int]] = None
@@ -205,6 +226,7 @@ class CompletionResponse(BaseModel):
"""
Completion response.
"""
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -217,6 +239,7 @@ class CompletionResponseStreamChoice(BaseModel):
"""
Completion response choice for stream response.
"""
index: int
text: str
arrival_time: float = None
@@ -231,6 +254,7 @@ class CompletionStreamResponse(BaseModel):
"""
Completion response for stream response.
"""
id: str
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -243,6 +267,7 @@ class StreamOptions(BaseModel):
"""
Stream options.
"""
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = False
@@ -251,9 +276,9 @@ class StructuralTag(BaseModel):
"""
Structural tag.
"""
begin: str
structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
alias="schema")
structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
end: str
@@ -261,9 +286,10 @@ class JsonSchemaResponseFormat(BaseModel):
"""
Json schema for ResponseFormat.
"""
name: str
description: Optional[str] = None
json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema')
json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema")
strict: Optional[bool] = None
@@ -271,6 +297,7 @@ class StructuralTagResponseFormat(BaseModel):
"""
Structural tag for ResponseFormat.
"""
type: Literal["structural_tag"]
structures: list[StructuralTag]
triggers: list[str]
@@ -280,6 +307,7 @@ class ResponseFormat(BaseModel):
"""
response_format type.
"""
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
@@ -291,6 +319,7 @@ class CompletionRequest(BaseModel):
"""
Completion request to the engine.
"""
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = "default"
@@ -333,7 +362,7 @@ class CompletionRequest(BaseModel):
"""
req_dict = {}
if request_id is not None:
req_dict['request_id'] = request_id
req_dict["request_id"] = request_id
for key, value in self.dict().items():
if value is not None:
req_dict[key] = value
@@ -341,7 +370,7 @@ class CompletionRequest(BaseModel):
for key, value in self.suffix.items():
req_dict[key] = value
if prompt is not None:
req_dict['prompt'] = prompt
req_dict["prompt"] = prompt
if isinstance(prompt[0], int):
req_dict["prompt_token_ids"] = prompt
@@ -363,8 +392,11 @@ class CompletionRequest(BaseModel):
req_dict["guided_json_object"] = guided_json_object
guided_schema = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
"structural_tag"
"guided_json",
"guided_regex",
"guided_choice",
"guided_grammar",
"structural_tag",
]
for key in guided_schema:
item = getattr(self, key, None)
@@ -380,15 +412,16 @@ class CompletionRequest(BaseModel):
Validate stream options
"""
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when `stream=True`.")
raise ValueError("Stream options can only be defined when `stream=True`.")
guided_count = sum([
guided_count = sum(
[
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None,
"guided_grammar" in data and data["guided_grammar"] is not None
])
"guided_grammar" in data and data["guided_grammar"] is not None,
]
)
if guided_count > 1:
raise ValueError(
@@ -403,6 +436,7 @@ class ChatCompletionRequest(BaseModel):
"""
Chat completion request to the engine.
"""
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: Union[List[Any], List[int]]
@@ -414,8 +448,8 @@ class ChatCompletionRequest(BaseModel):
# remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
deprecated=
'max_tokens is deprecated in favor of the max_completion_tokens field')
deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
)
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = None
@@ -451,7 +485,7 @@ class ChatCompletionRequest(BaseModel):
"""
req_dict = {}
if request_id is not None:
req_dict['request_id'] = request_id
req_dict["request_id"] = request_id
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
@@ -483,17 +517,18 @@ class ChatCompletionRequest(BaseModel):
self.guided_json = json_schema
elif self.response_format.type == "structural_tag":
structural_tag = self.response_format
assert structural_tag is not None and isinstance(
structural_tag, StructuralTagResponseFormat)
self.structural_tag = json.dumps(
structural_tag.model_dump(by_alias=True))
assert structural_tag is not None and isinstance(structural_tag, StructuralTagResponseFormat)
self.structural_tag = json.dumps(structural_tag.model_dump(by_alias=True))
if guided_json_object:
req_dict["guided_json_object"] = guided_json_object
guided_schema = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
"structural_tag"
"guided_json",
"guided_regex",
"guided_choice",
"guided_grammar",
"structural_tag",
]
for key in guided_schema:
item = getattr(self, key, None)
@@ -509,16 +544,17 @@ class ChatCompletionRequest(BaseModel):
Validate stream options
"""
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when `stream=True`.")
raise ValueError("Stream options can only be defined when `stream=True`.")
guided_count = sum([
guided_count = sum(
[
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None,
"guided_grammar" in data and data["guided_grammar"] is not None,
"structural_tag" in data and data["structural_tag"] is not None
])
"structural_tag" in data and data["structural_tag"] is not None,
]
)
if guided_count > 1:
raise ValueError(
@@ -537,9 +573,7 @@ class ChatCompletionRequest(BaseModel):
raise ValueError("`top_logprobs` must be a positive value.")
if top_logprobs > 0 and not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
return data
@@ -548,6 +582,7 @@ class ControlSchedulerRequest(BaseModel):
"""
Control scheduler request to the engine.
"""
reset: Optional[bool] = False
load_shards_num: Optional[int] = None
reallocate_shard: Optional[bool] = False
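To make the streaming wire shape concrete, a small sketch that builds one chunk by hand from the models above (id, model, and content values are placeholders; the import path matches the one used by the serving layer below):

from fastdeploy.entrypoints.openai.protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    DeltaMessage,
)

chunk = ChatCompletionStreamResponse(
    id="chatcmpl-example",
    model="default",
    choices=[
        ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role="assistant", content="Hello"),
        )
    ],
)
print(chunk.model_dump_json(exclude_unset=True))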

View File

@@ -15,21 +15,29 @@
"""
import asyncio
import json
import time
import traceback
import uuid
from typing import List, Optional
import msgpack
import aiozmq
import msgpack
from aiozmq import zmq
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbEntry, LogProbs, PromptTokenUsageInfo, UsageInfo)
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ErrorResponse,
LogProbEntry,
LogProbs,
PromptTokenUsageInfo,
UsageInfo,
)
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.worker.output import LogprobsLists
@@ -53,10 +61,7 @@ class OpenAIServingChat:
return True
return False
async def create_chat_completion(
self,
request: ChatCompletionRequest
):
async def create_chat_completion(self, request: ChatCompletionRequest):
"""
Create a new chat completion using the specified parameters.
"""
@@ -81,16 +86,10 @@ class OpenAIServingChat:
del current_req_dict
if request.stream:
return self.chat_completion_stream_generator(
request, request_id,
request.model,
prompt_token_ids)
return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
else:
try:
return await self.chat_completion_full_generator(
request, request_id,
request.model,
prompt_token_ids)
return await self.chat_completion_full_generator(request, request_id, request.model, prompt_token_ids)
except Exception as e:
return ErrorResponse(code=400, message=str(e))
@@ -106,7 +105,7 @@ class OpenAIServingChat:
request: ChatCompletionRequest,
request_id: str,
model_name: str,
prompt_token_ids: list()
prompt_token_ids: list(),
):
"""
Streaming chat completion generator.
@@ -135,14 +134,11 @@ class OpenAIServingChat:
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name
model=model_name,
)
try:
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
dealer.write([b"", request_id.encode('utf-8')])
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
dealer.write([b"", request_id.encode("utf-8")])
choices = []
current_waiting_time = 0
if request.metadata is not None:
@@ -171,20 +167,29 @@ class OpenAIServingChat:
raise ValueError("{}".format(res["error_msg"]))
self.engine_client.data_processor.process_response_dict(
res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
res,
stream=True,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
if res['metrics']['first_token_time'] is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time = res['metrics']['inference_start_time']
if res["metrics"]["first_token_time"] is not None:
arrival_time = res["metrics"]["first_token_time"]
inference_start_time = res["metrics"]["inference_start_time"]
else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time
arrival_time = res["metrics"]["arrival_time"] - inference_start_time
if first_iteration:
num_prompt_tokens = len(prompt_token_ids)
num_cached_tokens = res.get("num_cached_tokens", 0)
for i in range(num_choices):
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
delta=DeltaMessage(
role="assistant",
content="",
reasoning_content="",
tool_calls=None,
),
)
if request.metadata is not None and request.metadata.get("training", False):
choice.delta.token_ids = prompt_token_ids
@@ -193,14 +198,14 @@ class OpenAIServingChat:
object=chunk_object_type,
created=created_time,
choices=[choice],
model=model_name
model=model_name,
)
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
first_iteration = False
@@ -222,24 +227,32 @@ class OpenAIServingChat:
)
previous_num_tokens += len(output["token_ids"])
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
delta_message = DeltaMessage(
content=delta_text,
reasoning_content=output.get("reasoning_content"),
token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content", []),
)
choice = ChatCompletionResponseStreamChoice(
index=0,
delta=delta_message,
logprobs=logprobs_res,
arrival_time=arrival_time
arrival_time=arrival_time,
)
if res["finished"]:
num_choices -= 1
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
work_process_metrics.e2e_request_latency.observe(
time.time() - res["metrics"]["request_start_time"]
)
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
if (
self.engine_client.reasoning_parser == "ernie_x1"
and output.get("finish_reason", "") == "tool_calls"
):
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
@@ -253,7 +266,7 @@ class OpenAIServingChat:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=previous_num_tokens,
total_tokens=num_prompt_tokens + previous_num_tokens
total_tokens=num_prompt_tokens + previous_num_tokens,
)
choices.append(choice)
@@ -267,13 +280,12 @@ class OpenAIServingChat:
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = []
if include_usage:
completion_tokens = previous_num_tokens
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens
total_tokens=num_prompt_tokens + completion_tokens,
)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -281,7 +293,7 @@ class OpenAIServingChat:
created=created_time,
choices=[],
model=model_name,
usage=usage
usage=usage,
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
@@ -297,7 +309,7 @@ class OpenAIServingChat:
request: ChatCompletionRequest,
request_id: str,
model_name: str,
prompt_token_ids: list()
prompt_token_ids: list(),
):
"""
Full chat completion generator.
@@ -307,11 +319,8 @@ class OpenAIServingChat:
enable_thinking = None
include_stop_str_in_output = False
try:
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
dealer.write([b"", request_id.encode('utf-8')])
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
dealer.write([b"", request_id.encode("utf-8")])
final_res = None
previous_num_tokens = 0
current_waiting_time = 0
@@ -340,7 +349,11 @@ class OpenAIServingChat:
enable_thinking = request.metadata.get("enable_thinking")
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
data = self.engine_client.data_processor.process_response_dict(
data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
data,
stream=False,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
# api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"])
# The logprob for handling the response
@@ -375,26 +388,23 @@ class OpenAIServingChat:
content=output["text"],
reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call_content"),
token_ids=output.get("token_ids")
token_ids=output.get("token_ids"),
)
logprobs_full_res = None
if logprob_contents:
logprobs_full_res = LogProbs(
content=logprob_contents
)
logprobs_full_res = LogProbs(content=logprob_contents)
choice = ChatCompletionResponseChoice(
index=0,
message=message,
logprobs=logprobs_full_res,
finish_reason=None
finish_reason=None,
)
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
@@ -409,7 +419,7 @@ class OpenAIServingChat:
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0))
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0)),
)
work_process_metrics.e2e_request_latency.observe(time.time() - final_res["metrics"]["request_start_time"])
return ChatCompletionResponse(
@@ -417,7 +427,7 @@ class OpenAIServingChat:
created=created_time,
model=model_name,
choices=choices,
usage=usage
usage=usage,
)
def build_logprobs_response(
@@ -454,8 +464,9 @@ class OpenAIServingChat:
# Construct the candidate token structure (LogProbEntry) of topk
top_logprob_entries: List[LogProbEntry] = []
for tid, lp in zip(topk_token_ids, topk_logprobs):
token_str = self.engine_client.data_processor.process_logprob_response([tid],
clean_up_tokenization_spaces=False)
token_str = self.engine_client.data_processor.process_logprob_response(
[tid], clean_up_tokenization_spaces=False
)
# token_bytes = token_str.encode("utf-8", errors="replace")
entry = LogProbEntry(
token=token_str,
@@ -468,7 +479,7 @@ class OpenAIServingChat:
token=top_logprob_entries[0].token,
logprob=top_logprob_entries[0].logprob,
bytes=top_logprob_entries[0].bytes,
top_logprobs=top_logprob_entries[1:] # Here are the complete topk candidates
top_logprobs=top_logprob_entries[1:], # Here are the complete topk candidates
)
return LogProbs(content=[sampled_entry])
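For reference, the shape this helper returns can be rebuilt by hand from the protocol models (token strings and logprob values below are made up):

from fastdeploy.entrypoints.openai.protocol import LogProbEntry, LogProbs

sampled = LogProbEntry(
    token="Hello",
    logprob=-0.12,
    top_logprobs=[  # the remaining top-k candidates, as in build_logprobs_response above
        LogProbEntry(token="Hi", logprob=-2.31),
        LogProbEntry(token="Hey", logprob=-3.05),
    ],
)
print(LogProbs(content=[sampled]).model_dump_json())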

View File

@@ -15,33 +15,25 @@
"""
import asyncio
import time
import uuid
from typing import List
import aiozmq
import json
import msgpack
from aiozmq import zmq
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Optional, Union, cast, TypeVar, List
import uuid
from fastapi import Request
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
ErrorResponse,
CompletionRequest,
CompletionResponse,
CompletionStreamResponse,
CompletionResponseStreamChoice,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
UsageInfo,
DeltaToolCall,
DeltaFunctionCall,
ToolCall,
FunctionCall
)
from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.engine.request import RequestOutput
class OpenAIServingCompletion:
@@ -105,9 +97,7 @@ class OpenAIServingCompletion:
current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
try:
current_req_dict["arrival_time"] = time.time()
prompt_batched_token_ids.append(
self.engine_client.format_and_add_data(current_req_dict)
)
prompt_batched_token_ids.append(self.engine_client.format_and_add_data(current_req_dict))
except Exception as e:
return ErrorResponse(message=str(e), code=400)
@@ -120,7 +110,7 @@ class OpenAIServingCompletion:
request_id=request_id,
created_time=created_time,
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids
prompt_batched_token_ids=prompt_batched_token_ids,
)
else:
try:
@@ -130,7 +120,7 @@ class OpenAIServingCompletion:
request_id=request_id,
created_time=created_time,
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids
prompt_batched_token_ids=prompt_batched_token_ids,
)
except Exception as e:
return ErrorResponse(code=400, message=str(e))
@@ -138,7 +128,6 @@ class OpenAIServingCompletion:
except Exception as e:
return ErrorResponse(message=str(e), code=400)
async def completion_full_generator(
self,
request: CompletionRequest,
@@ -146,7 +135,7 @@ class OpenAIServingCompletion:
request_id: str,
created_time: int,
model_name: str,
prompt_batched_token_ids: list()
prompt_batched_token_ids: list(),
):
"""
Process the full completion request with multiple choices.
@@ -155,10 +144,7 @@ class OpenAIServingCompletion:
try:
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
# create dealer
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")])
@@ -186,8 +172,7 @@ class OpenAIServingCompletion:
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
self.engine_client.data_processor.process_response_dict(
data, stream=False)
self.engine_client.data_processor.process_response_dict(data, stream=False)
output_tokens[rid] += len(data["outputs"]["token_ids"])
if data.get("finished", False):
data["output_token_ids"] = output_tokens[rid]
@@ -201,18 +186,15 @@ class OpenAIServingCompletion:
request_id=request_id,
created_time=created_time,
model_name=model_name,
prompt_batched_token_ids=prompt_batched_token_ids
prompt_batched_token_ids=prompt_batched_token_ids,
)
except Exception as e:
api_server_logger.error(
f"Error in completion_full_generator: {e}", exc_info=True
)
api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
raise
finally:
if dealer is not None:
dealer.close()
async def completion_stream_generator(
self,
request: CompletionRequest,
@@ -220,20 +202,17 @@ class OpenAIServingCompletion:
request_id: str,
created_time: int,
model_name: str,
prompt_batched_token_ids: list()
prompt_batched_token_ids: list(),
):
"""
Process the stream completion request.
"""
try:
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
connect=f"ipc:///dev/shm/router_{self.pid}.ipc"
)
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
for i in range(num_choices):
req_id = f"{request_id}-{i}"
dealer.write([b"", req_id.encode('utf-8')]) # send one request per choice
dealer.write([b"", req_id.encode("utf-8")]) # send one request per choice
output_tokens = [0] * num_choices
inference_start_time = [0] * num_choices
first_iteration = [True] * num_choices
@@ -245,7 +224,7 @@ class OpenAIServingCompletion:
id=request_id,
created=created_time,
model=model_name,
choices=choices
choices=choices,
)
current_waiting_time = 0
@@ -264,7 +243,6 @@ class OpenAIServingCompletion:
await asyncio.sleep(0.1)
continue
response = msgpack.unpackb(raw_data[-1])
for res in response:
idx = int(res["request_id"].split("-")[-1])
@@ -277,39 +255,43 @@ class OpenAIServingCompletion:
id=request_id,
created=created_time,
model=model_name,
choices=[CompletionResponseStreamChoice(
choices=[
CompletionResponseStreamChoice(
index=idx,
text="",
token_ids=list(prompt_batched_token_ids[idx])
)]
token_ids=list(prompt_batched_token_ids[idx]),
)
],
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
first_iteration[idx] = False
self.engine_client.data_processor.process_response_dict(
res, stream=True)
if res['metrics'].get('first_token_time') is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time[idx] = res['metrics']['inference_start_time']
self.engine_client.data_processor.process_response_dict(res, stream=True)
if res["metrics"].get("first_token_time") is not None:
arrival_time = res["metrics"]["first_token_time"]
inference_start_time[idx] = res["metrics"]["inference_start_time"]
else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx]
arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]
output = res["outputs"]
choices.append(CompletionResponseStreamChoice(
choices.append(
CompletionResponseStreamChoice(
index=idx,
text=output["text"],
token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content"),
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time
))
arrival_time=arrival_time,
)
)
if res["finished"]:
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
chunk.choices[0].finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
if (
self.engine_client.reasoning_parser == "ernie_x1"
and output.get("finish_reason", "") == "tool_calls"
):
chunk.choices[0].finish_reason = "tool_calls"
else:
chunk.choices[0].finish_reason = "length"
@@ -321,12 +303,11 @@ class OpenAIServingCompletion:
id=request_id,
created=created_time,
model=model_name,
choices=choices
choices=choices,
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = []
if res["finished"]:
num_choices -= 1
if getattr(request, "stream_options", None) and request.stream_options.include_usage:
@@ -337,8 +318,8 @@ class OpenAIServingCompletion:
choices=[],
usage=UsageInfo(
prompt_tokens=len(prompt_batched_token_ids[idx]),
completion_tokens=output_tokens[idx]
)
completion_tokens=output_tokens[idx],
),
)
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
if choices:
@@ -346,7 +327,6 @@ class OpenAIServingCompletion:
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = []
except Exception as e:
yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
finally:
@@ -355,7 +335,6 @@ class OpenAIServingCompletion:
dealer.close()
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: List[RequestOutput],
@@ -363,7 +342,7 @@ class OpenAIServingCompletion:
request_id: str,
created_time: int,
model_name: str,
prompt_batched_token_ids: list()
prompt_batched_token_ids: list(),
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -389,12 +368,13 @@ class OpenAIServingCompletion:
output_text = output["text"]
choice_data = CompletionResponseChoice(
token_ids=token_ids,
index=len(choices),
text=output_text,
reasoning_content=output.get('reasoning_content'),
reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call_content"),
logprobs=None,
finish_reason=None
finish_reason=None,
)
choices.append(choice_data)

View File

@@ -42,7 +42,7 @@ response = client.completions.create(
)
for chunk in response:
print(chunk.choices[0].text, end='')
print(chunk.choices[0].text, end="")
print("\n")
# Chat completion
@@ -76,5 +76,5 @@ response = client.chat.completions.create(
for chunk in response:
if chunk.choices[0].delta is not None:
print(chunk.choices[0].delta, end='')
print(chunk.choices[0].delta, end="")
print("\n")

View File

@@ -20,115 +20,62 @@ from typing import Any, Callable
environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use BF16 on CPU.
"FD_CPU_USE_BF16":
lambda: os.getenv("FD_CPU_USE_BF16", "False"),
"FD_CPU_USE_BF16": lambda: os.getenv("FD_CPU_USE_BF16", "False"),
# CUDA architectures to build FastDeploy. This is a list of strings
# such as [80,90].
"FD_BUILDING_ARCS":
lambda: os.getenv("FD_BUILDING_ARCS", "[]"),
"FD_BUILDING_ARCS": lambda: os.getenv("FD_BUILDING_ARCS", "[]"),
# Log directory.
"FD_LOG_DIR":
lambda: os.getenv("FD_LOG_DIR", "log"),
"FD_LOG_DIR": lambda: os.getenv("FD_LOG_DIR", "log"),
# Whether to use debug mode, can set 0 or 1
"FD_DEBUG":
lambda: os.getenv("FD_DEBUG", "0"),
"FD_DEBUG": lambda: os.getenv("FD_DEBUG", "0"),
# Number of days to keep fastdeploy logs.
"FD_LOG_BACKUP_COUNT":
lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
"FD_LOG_BACKUP_COUNT": lambda: os.getenv("FD_LOG_BACKUP_COUNT", "7"),
# Model download cache directory.
"FD_MODEL_CACHE":
lambda: os.getenv("FD_MODEL_CACHE", None),
"FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
# Maximum number of stop sequences.
"FD_MAX_STOP_SEQS_NUM":
lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
"FD_MAX_STOP_SEQS_NUM": lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
# Maximum length of stop sequences.
"FD_STOP_SEQS_MAX_LEN":
lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
"FD_STOP_SEQS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
# GPU devices that will be used. This is a comma-separated
# string, such as 0,1,2.
"CUDA_VISIBLE_DEVICES":
lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
"CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
# Whether to use HuggingFace tokenizer.
"FD_USE_HF_TOKENIZER":
lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
"FD_USE_HF_TOKENIZER": lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
# Set the high watermark (HWM) for receiving data during ZMQ initialization
"FD_ZMQ_SNDHWM":
lambda: os.getenv("FD_ZMQ_SNDHWM", 10000),
"FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 10000),
# cache kv quant params directory
"FD_CACHE_PARAMS":
lambda: os.getenv("FD_CACHE_PARAMS", "none"),
"FD_CACHE_PARAMS": lambda: os.getenv("FD_CACHE_PARAMS", "none"),
# Set attention backend. "NATIVE_ATTN", "APPEND_ATTN"
# and "MLA_ATTN" can be set currently.
"FD_ATTENTION_BACKEND":
lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
"FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
# Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
"FD_SAMPLING_CLASS":
lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
# Set MoE backend. "cutlass", "marlin" and "triton" can be set currently.
"FD_MOE_BACKEND":
lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Set whether to disable recomputing the request when the KV cache is full.
"FD_DISABLED_RECOVER":
lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
"FD_DISABLED_RECOVER": lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
# Set triton kernel JIT compilation directory.
"FD_TRITON_KERNEL_CACHE_DIR":
lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
"FD_TRITON_KERNEL_CACHE_DIR": lambda: os.getenv("FD_TRITON_KERNEL_CACHE_DIR", None),
# Whether to transition from standalone PD decoupling to centralized inference
"FD_PD_CHANGEABLE":
lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
"FD_PD_CHANGEABLE": lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
# Whether to use fastsafetensor load weight (0 or 1)
"FD_USE_FASTSAFETENSOR":
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
"FD_USE_FASTSAFETENSOR": lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
# Whether to use DeepGemm for FP8 blockwise MoE.
"FD_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
"FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
# Whether to use aggregate send.
"FD_USE_AGGREGATE_SEND":
lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
"FD_USE_AGGREGATE_SEND": lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
# Whether to enable tracing.
"TRACES_ENABLE":
lambda: os.getenv("TRACES_ENABLE", "false"),
"TRACES_ENABLE": lambda: os.getenv("TRACES_ENABLE", "false"),
# Set trace server name.
"FD_SERVICE_NAME":
lambda: os.getenv("FD_SERVICE_NAME", "FastDeploy"),
"FD_SERVICE_NAME": lambda: os.getenv("FD_SERVICE_NAME", "FastDeploy"),
# Set trace host name.
"FD_HOST_NAME":
lambda: os.getenv("FD_HOST_NAME", "localhost"),
"FD_HOST_NAME": lambda: os.getenv("FD_HOST_NAME", "localhost"),
# Set trace exporter.
"TRACES_EXPORTER":
lambda: os.getenv("TRACES_EXPORTER", "console"),
"TRACES_EXPORTER": lambda: os.getenv("TRACES_EXPORTER", "console"),
# Set trace exporter_otlp_endpoint.
"EXPORTER_OTLP_ENDPOINT":
lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"),
"EXPORTER_OTLP_ENDPOINT": lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"),
# Set trace exporter_otlp_headers.
"EXPORTER_OTLP_HEADERS":
lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
"EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
}
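For reference, a minimal sketch of how a lazy registry like this can be consumed: each entry maps a name to a zero-argument lambda, so os.getenv is only called when the value is looked up. The get_env helper and the two sample entries below are illustrative, not the exact FastDeploy API.

import os

# Illustrative registry in the same shape as the dict above: each value is a
# zero-argument lambda, so the environment is read lazily at lookup time.
environment_variables = {
    "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
    "FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
}

def get_env(name):
    """Evaluate the registered lambda for `name` and return its current value."""
    return environment_variables[name]()

print(get_env("FD_ATTENTION_BACKEND"))  # "APPEND_ATTN" unless overridden
print(get_env("FD_USE_DEEP_GEMM"))      # True unless FD_USE_DEEP_GEMM=0 is set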

View File

@@ -43,8 +43,7 @@ def import_custom_ops(package, module_name, global_ns):
logger.warning(f"Failed to import op {func_name}: {e}")
except Exception:
logger.warning(
f"Ops of {package} import failed, it may be not compiled.")
logger.warning(f"Ops of {package} import failed, it may be not compiled.")
preprocess_static_op(global_ns)
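The pattern above (import each op individually and downgrade failures to warnings) can be sketched roughly as follows. Module and logger names are placeholders, and the real function also calls preprocess_static_op, which is omitted here.

import importlib
import logging

logger = logging.getLogger(__name__)

def import_custom_ops(package, module_name, global_ns):
    """Best-effort import: a missing or uncompiled op only logs a warning."""
    try:
        module = importlib.import_module(module_name)
        for func_name in dir(module):
            if func_name.startswith("_"):
                continue
            try:
                global_ns[func_name] = getattr(module, func_name)
            except Exception as e:
                logger.warning(f"Failed to import op {func_name}: {e}")
    except Exception:
        logger.warning(f"Ops of {package} import failed, it may be not compiled.")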

View File

@@ -69,12 +69,12 @@ class ErnieProcessor(BaseDataProcessor):
# Generation config
try:
self.generation_config = GenerationConfig.from_pretrained(
self.model_name_or_path)
self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
except Exception as e:
data_processor_logger.warning(
f"Can't find generation config, so it will not use "
f"generation_config field in the model config, details={e}")
f"generation_config field in the model config, details={e}"
)
self.generation_config = None
def process_request(self, request, max_model_len=None, **kwargs):
@@ -89,8 +89,7 @@ class ErnieProcessor(BaseDataProcessor):
str: error message
"""
request = self._apply_default_parameters(request)
if request.get("eos_token_ids") is None or len(
request.eos_token_ids) == 0:
if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
request.eos_token_ids = self.eos_token_ids
stop_sequences = request.get("stop", [])
if stop_sequences is not None and len(stop_sequences) != 0:
@@ -98,12 +97,9 @@ class ErnieProcessor(BaseDataProcessor):
request.set("stop_token_ids", stop_seqs)
request.set("stop_seqs_len", stop_seqs_len)
if request.prompt_token_ids is None or len(
request.prompt_token_ids) == 0:
system = request.get("system")
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is None and request.messages is None:
raise ValueError(
f"The request should have `input_ids`, `text` or `messages`: {request}.")
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
if request.prompt is not None or not request.raw_request:
prompt = request.prompt if request.prompt is not None else request.messages[0]
prompt = prompt[0] if isinstance(prompt, list) else prompt
@@ -114,14 +110,13 @@ class ErnieProcessor(BaseDataProcessor):
else:
request.prompt_token_ids = self.messages2ids(request.to_dict())
if max_model_len is not None and len(
request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[:
max_model_len -
1]
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
if request.get("max_tokens") is None:
request.set("max_tokens",
max(1, max_model_len - len(request.prompt_token_ids)))
request.set(
"max_tokens",
max(1, max_model_len - len(request.prompt_token_ids)),
)
if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling
request.set("temperature", 1)
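The truncation and defaulting logic in this hunk reduces to the small helper below. It is a sketch with illustrative names, not the FastDeploy request API: prompts longer than the context window are clipped to max_model_len - 1, and max_tokens defaults to whatever budget remains.

def clamp_prompt(prompt_token_ids, max_model_len, max_tokens=None):
    """Clip an over-long prompt and derive a default generation budget."""
    if max_model_len is not None and len(prompt_token_ids) > max_model_len:
        prompt_token_ids = prompt_token_ids[: max_model_len - 1]
    if max_tokens is None:
        max_tokens = max(1, max_model_len - len(prompt_token_ids))
    return prompt_token_ids, max_tokens

ids, budget = clamp_prompt(list(range(10)), max_model_len=8)
print(len(ids), budget)  # 7 1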
@@ -140,45 +135,36 @@ class ErnieProcessor(BaseDataProcessor):
str: error message
"""
request = self._apply_default_parameters(request)
if not request.get('eos_token_ids'):
request['eos_token_ids'] = self.eos_token_ids
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
# Handle stop_sequences
stop_sequences = request.get('stop', [])
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request['stop_token_ids'] = stop_seqs
request['stop_seqs_len'] = stop_seqs_len
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
system = request.get("system")
# Handle prompt_token_ids
if not request.get('prompt_token_ids'):
if request.get('prompt') is None and request.get(
'messages') is None:
raise ValueError(
f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}"
)
if request.get('prompt'):
prompt = request.get('prompt')
if not request.get("prompt_token_ids"):
if request.get("prompt") is None and request.get("messages") is None:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
if request.get("prompt"):
prompt = request.get("prompt")
prompt = prompt[0] if isinstance(prompt, list) else prompt
tokens = self.tokenizer.tokenize(prompt)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request['prompt_token_ids'] = token_ids
request["prompt_token_ids"] = token_ids
req_id = request.get("request_id", None)
data_processor_logger.info(
f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}"
)
data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
request['prompt_token_ids'] = self.messages2ids(request)
request["prompt_token_ids"] = self.messages2ids(request)
# Truncate prompts that exceed the length limit
if max_model_len is not None and len(
request['prompt_token_ids']) > max_model_len:
request['prompt_token_ids'] = request[
'prompt_token_ids'][:max_model_len - 1]
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None:
request["max_tokens"] = max(
1, max_model_len - len(request['prompt_token_ids']))
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
if request.get("temperature") < _SAMPLING_EPS:
# zero temperature is equivalent to greedy sampling
request["temperature"] = 1
@@ -200,22 +186,18 @@ class ErnieProcessor(BaseDataProcessor):
req_id = response_dict.request_id
token_ids = response_dict.outputs.token_ids
response_dict.usage = {
"completion_tokens": response_dict.outputs.index + 1
}
response_dict.usage = {"completion_tokens": response_dict.outputs.index + 1}
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
full_text = self.tokenizer.decode(token_ids)
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text, response_dict)
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict.outputs.text = text
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
if response_dict.outputs.text == "" and \
response_dict.outputs.reasoning_content == "":
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
return response_dict
@@ -230,8 +212,7 @@ class ErnieProcessor(BaseDataProcessor):
Dict: response contain text fields
"""
if stream:
return self.process_response_dict_streaming(
response_dict, **kwargs)
return self.process_response_dict_streaming(response_dict, **kwargs)
else:
return self.process_response_dict_normal(response_dict, **kwargs)
@@ -255,16 +236,12 @@ class ErnieProcessor(BaseDataProcessor):
if is_end:
full_text = previous_texts + delta_text
if self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(
full_text, response_dict)
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = text
response_dict["outputs"][
"reasoning_content"] = reasoning_content
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
data_processor_logger.info(
f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}"
)
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
return response_dict
@@ -286,20 +263,22 @@ class ErnieProcessor(BaseDataProcessor):
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
token_ids, req_id)
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
if enable_thinking and self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts, previous_texts + delta_text, delta_text,
previous_token_ids, previous_token_ids + token_ids, token_ids)
previous_texts,
previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
if is_end:
data_processor_logger.info(
f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}"
)
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
return response_dict
@@ -320,15 +299,15 @@ class ErnieProcessor(BaseDataProcessor):
request_or_messages,
tokenize=False,
split_special_tokens=False,
add_special_tokens=False)
add_special_tokens=False,
)
req_id = None
if isinstance(request_or_messages, dict):
req_id = request_or_messages.get("request_id", None)
tokens = self.tokenizer.tokenize(spliced_message)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
data_processor_logger.info(
f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
return token_ids
def ids2tokens(self, token_id, task_id):
@@ -352,7 +331,8 @@ class ErnieProcessor(BaseDataProcessor):
previous_token_ids = self.decode_status[task_id][2]
previous_texts = self.decode_status[task_id][3]
decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(
previous_token_ids + token_id, prefix_offset, read_offset)
previous_token_ids + token_id, prefix_offset, read_offset
)
self.decode_status[task_id][0] = prefix_offset
self.decode_status[task_id][1] = read_offset
self.decode_status[task_id][2] += token_id
@@ -368,17 +348,15 @@ class ErnieProcessor(BaseDataProcessor):
tokenizer (AutoTokenizer)
"""
vocab_file_names = [
"tokenizer.model", "spm.model", "ernie_token_100k.model"
"tokenizer.model",
"spm.model",
"ernie_token_100k.model",
]
for i in range(len(vocab_file_names)):
if os.path.exists(
os.path.join(self.model_name_or_path,
vocab_file_names[i])):
ErnieBotTokenizer.resource_files_names[
"vocab_file"] = vocab_file_names[i]
if os.path.exists(os.path.join(self.model_name_or_path, vocab_file_names[i])):
ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
break
self.tokenizer = ErnieBotTokenizer.from_pretrained(
self.model_name_or_path)
self.tokenizer = ErnieBotTokenizer.from_pretrained(self.model_name_or_path)
def get_pad_id(self):
"""
@@ -391,16 +369,17 @@ class ErnieProcessor(BaseDataProcessor):
# return self.tokenizer.eos_token
return self.tokenizer.pad_token_id
def pad_batch_data(self,
def pad_batch_data(
self,
insts,
pad_id=0,
return_seq_len=False,
return_array=True,
pad_style="right"):
pad_style="right",
):
"""Pad the instances to the max sequence length in batch."""
if len(insts) == 0:
padded_insts = np.array([[]],
dtype=np.int64) if return_array else [[]]
padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
if return_seq_len:
seq_len = np.array([], dtype=np.int64) if return_array else []
return padded_insts, seq_len
@@ -408,15 +387,11 @@ class ErnieProcessor(BaseDataProcessor):
max_len = max(map(len, insts))
if pad_style == "left":
padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst)
for inst in insts]
padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
else:
padded_insts = [
list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts
]
padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
if return_array:
padded_insts = np.array(padded_insts,
dtype=np.int64).reshape([-1, max_len])
padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])
if return_seq_len:
seq_len = [len(inst) for inst in insts]
@@ -432,15 +407,9 @@ class ErnieProcessor(BaseDataProcessor):
stop_seqs = []
for seq in stop_sequences:
if seq != self.tokenizer.eos_token_id:
stop_seqs.append(
self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(seq)))
stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs,
pad_id=-1,
return_seq_len=True,
return_array=False)
data_processor_logger.debug(
f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
stop_seqs, stop_seqs_len = self.pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True, return_array=False)
data_processor_logger.debug(f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
return stop_seqs, stop_seqs_len
def process_logprob_response(self, token_ids, **kwargs):
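A compact illustration of what pad_batch_data does for the stop-sequence path above. This is a standalone re-statement for clarity, covering only the right-padding, list-output case:

def pad_batch_data(insts, pad_id=0, return_seq_len=False):
    """Right-pad variable-length id lists to the batch maximum (list output only)."""
    max_len = max(map(len, insts)) if insts else 0
    padded = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
    if return_seq_len:
        return padded, [len(inst) for inst in insts]
    return padded

stop_seqs, stop_seqs_len = pad_batch_data([[42, 7], [9]], pad_id=-1, return_seq_len=True)
print(stop_seqs, stop_seqs_len)  # [[42, 7], [9, -1]] [2, 1]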

View File

@@ -19,19 +19,14 @@
import os
import re
from shutil import copyfile
from typing import Dict, Optional, Tuple, List
from typing import Dict, List, Optional, Tuple
import numpy as np
import sentencepiece as spm
import paddle
from paddleformers.utils.log import logger
import sentencepiece as spm
from paddleformers.transformers import PretrainedTokenizer
from paddleformers.transformers.tokenizer_utils_base import (
PaddingStrategy,
TextInput,
)
from paddleformers.transformers.tokenizer_utils_base import PaddingStrategy, TextInput
from paddleformers.utils.log import logger
class ErnieBotTokenizer(PretrainedTokenizer):
@@ -47,7 +42,12 @@ class ErnieBotTokenizer(PretrainedTokenizer):
pretrained_init_configuration = {
"ernie-bot-10b": {},
}
model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
model_input_names = [
"input_ids",
"position_ids",
"attention_mask",
"labels",
]
padding_side = "right"
def __init__(
@@ -222,9 +222,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
# TODO: should this be in the base class?
if hasattr(self, "do_lower_case") and self.do_lower_case:
# convert non-special tokens to lowercase
escaped_special_toks = [
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)
]
escaped_special_toks = [re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_spec_tok)]
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
@@ -303,7 +301,12 @@ class ErnieBotTokenizer(PretrainedTokenizer):
elif not isinstance(attention_mask, np.ndarray):
raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ")
else:
attention_mask = np.tril(np.ones((len(required_input), len(required_input)), dtype=np.int64))
attention_mask = np.tril(
np.ones(
(len(required_input), len(required_input)),
dtype=np.int64,
)
)
attention_mask = np.expand_dims(attention_mask, axis=0)
if needs_to_be_padded:
difference = max_length - len(required_input)
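The reformatted np.tril call above builds a causal (lower-triangular) attention mask with a leading batch axis. A small runnable check, where required_input is just an illustrative id list:

import numpy as np

required_input = [101, 7, 8, 102]  # illustrative token ids
attention_mask = np.tril(np.ones((len(required_input), len(required_input)), dtype=np.int64))
attention_mask = np.expand_dims(attention_mask, axis=0)  # add a batch axis
print(attention_mask.shape)  # (1, 4, 4)
print(attention_mask[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]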

View File

@@ -17,18 +17,23 @@
import os
import numpy as np
import re
from fastdeploy.input.mm_processor import DataProcessor, IDS_TYPE_FLAG
from fastdeploy.input.ernie_processor import ErnieProcessor
from fastdeploy.engine.request import Request
from fastdeploy.entrypoints.chat_utils import parse_chat_messages
from fastdeploy.input.ernie_processor import ErnieProcessor
from fastdeploy.input.mm_processor import IDS_TYPE_FLAG, DataProcessor
from fastdeploy.utils import data_processor_logger
class ErnieMoEVLProcessor(ErnieProcessor):
"""The processor class for ERNIE MoE VL models."""
def __init__(self, model_name_or_path, limit_mm_per_prompt=None, mm_processor_kwargs=None,
reasoning_parser_obj=None):
def __init__(
self,
model_name_or_path,
limit_mm_per_prompt=None,
mm_processor_kwargs=None,
reasoning_parser_obj=None,
):
self.use_hf_tokenizer = False
if "merge_llm_model" in model_name_or_path:
@@ -41,7 +46,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.ernie_processor = DataProcessor(
tokenizer_name=tokenizer_path,
image_preprocessor_name=preprocessor_path,
**processor_kwargs
**processor_kwargs,
)
self.ernie_processor.eval()
self.image_patch_id = self.ernie_processor.image_patch_id
@@ -73,7 +78,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
def process_request(self, request, max_model_len=None, **kwargs):
"""process the input data"""
task = request.to_dict()
task['enable_thinking'] = kwargs.get("enable_thinking", True)
task["enable_thinking"] = kwargs.get("enable_thinking", True)
self.process_request_dict(task, max_model_len)
request = Request.from_dict(task)
@@ -101,13 +106,14 @@ class ErnieMoEVLProcessor(ErnieProcessor):
"video_frames_sample": str,
"video_max_frames": int,
"video_min_frames": int,
"video_fps": int
"video_fps": int,
}
for key, value in kwargs.items():
if key in expected_types and not isinstance(value, expected_types[key]):
raise ValueError(
f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}")
f"Invalid type for {key}: expected {expected_types[key].__name__}, got {type(value).__name__}"
)
return kwargs
@@ -117,11 +123,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
def _parse_limits(self, limits):
"""解析多模态限制配置"""
DEFAULT_LIMITS = {
"image": 1,
"video": 1,
"audio": 1
}
DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}
if not limits:
return DEFAULT_LIMITS
@@ -141,10 +143,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
mm_data = item
else:
# The request contains messages
mm_data = {
"image": [],
"video": []
}
mm_data = {"image": [], "video": []}
for message in item:
if isinstance(message.get("content"), list):
@@ -158,10 +157,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
if modality in self.limit_mm_per_prompt:
limit = self.limit_mm_per_prompt[modality]
if len(data) > limit:
raise ValueError(
f"Too many {modality} items in prompt, "
f"got {len(data)} but limit is {limit}"
)
raise ValueError(f"Too many {modality} items in prompt, got {len(data)} but limit is {limit}")
def process_request_dict(self, request, max_model_len=None):
"""process the input data"""
@@ -200,13 +196,10 @@ class ErnieMoEVLProcessor(ErnieProcessor):
request["multimodal_inputs"] = outputs
# Truncate prompts that exceed the length limit
if max_model_len is not None and len(
request['prompt_token_ids']) > max_model_len:
request['prompt_token_ids'] = request[
'prompt_token_ids'][:max_model_len - 1]
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None:
request["max_tokens"] = max(
1, max_model_len - len(request['prompt_token_ids']))
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
data_processor_logger.info(f"Processed request {request}")
return request
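The per-modality limit check earlier in this file can be summarized as follows. This sketch works on plain dicts; the real processor operates on a parsed request object and its configured limit_mm_per_prompt:

DEFAULT_LIMITS = {"image": 1, "video": 1, "audio": 1}

def check_mm_limits(mm_data, limit_mm_per_prompt=DEFAULT_LIMITS):
    """Raise if any modality appears more often than its configured limit."""
    for modality, data in mm_data.items():
        limit = limit_mm_per_prompt.get(modality)
        if limit is not None and len(data) > limit:
            raise ValueError(f"Too many {modality} items in prompt, got {len(data)} but limit is {limit}")

check_mm_limits({"image": ["img0"], "video": []})  # passes
# check_mm_limits({"image": ["img0", "img1"]})     # would raise ValueError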

View File

@@ -14,10 +14,10 @@
# limitations under the License.
"""
from .process import DataProcessor, fancy_print, IDS_TYPE_FLAG
from .process import IDS_TYPE_FLAG, DataProcessor, fancy_print
__all__ = [
'DataProcessor',
'fancy_print',
'IDS_TYPE_FLAG',
"DataProcessor",
"fancy_print",
"IDS_TYPE_FLAG",
]

View File

@@ -17,4 +17,4 @@
from .get_image_preprocessor import get_image_preprocessor
from .image_preprocessor_adaptive import AdaptiveImageProcessor
__all__ = ['get_image_preprocessor', 'AdaptiveImageProcessor']
__all__ = ["get_image_preprocessor", "AdaptiveImageProcessor"]

View File

@@ -16,9 +16,10 @@
"""get image preprocessor"""
from .image_preprocessor_adaptive import AdaptiveImageProcessor
from fastdeploy.utils import data_processor_logger
from .image_preprocessor_adaptive import AdaptiveImageProcessor
def get_image_preprocessor(args):
"""

Some files were not shown because too many files have changed in this diff.