mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -12,13 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
Utility functions and classes for FastDeploy server operations.
|
||||
This module provides:
|
||||
- Custom logging handlers and formatters
|
||||
- File download and extraction utilities
|
||||
- Configuration parsing helpers
|
||||
- Various helper functions for server operations
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -33,19 +26,21 @@ import time
|
||||
from datetime import datetime
|
||||
from logging.handlers import BaseRotatingHandler
|
||||
from pathlib import Path
|
||||
from typing import Literal, TypeVar, Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
from aistudio_sdk.snapshot_download import snapshot_download
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class EngineError(Exception):
|
||||
"""Base exception class for engine-related errors.
|
||||
|
||||
Attributes:
|
||||
message (str): Human-readable error description
|
||||
error_code (int): HTTP-style error code (default: 400)
|
||||
"""
|
||||
"""Base exception class for engine errors"""
|
||||
|
||||
def __init__(self, message, error_code=400):
|
||||
super().__init__(message)
|
||||
@@ -53,13 +48,7 @@ class EngineError(Exception):
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Custom log formatter that adds color to console output.
|
||||
|
||||
Colors different log levels for better visibility:
|
||||
- WARNING: Yellow
|
||||
- ERROR: Red
|
||||
- CRITICAL: Red
|
||||
"""
|
||||
"""自定义日志格式器,用于控制台输出带颜色"""
|
||||
COLOR_CODES = {
|
||||
logging.WARNING: 33, # 黄色
|
||||
logging.ERROR: 31, # 红色
|
||||
@@ -77,10 +66,8 @@ class ColoredFormatter(logging.Formatter):
|
||||
|
||||
|
||||
class DailyRotatingFileHandler(BaseRotatingHandler):
|
||||
"""Daily rotating file handler that supports multi-process logging.
|
||||
|
||||
Similar to `logging.TimedRotatingFileHandler` but designed to work safely
|
||||
in multi-process environments.
|
||||
"""
|
||||
like `logging.TimedRotatingFileHandler`, but this class support multi-process
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -90,19 +77,20 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
|
||||
delay=False,
|
||||
utc=False,
|
||||
**kwargs):
|
||||
"""Initialize the rotating file handler.
|
||||
"""
|
||||
初始化 RotatingFileHandler 对象。
|
||||
|
||||
Args:
|
||||
filename (str): Path to the log file (can be relative or absolute)
|
||||
backupCount (int, optional): Number of backup files to keep. Defaults to 0.
|
||||
encoding (str, optional): File encoding. Defaults to "utf-8".
|
||||
delay (bool, optional): Delay file opening until first write. Defaults to False.
|
||||
utc (bool, optional): Use UTC timezone for rollover. Defaults to False.
|
||||
**kwargs: Additional arguments passed to BaseRotatingHandler.
|
||||
filename (str): 日志文件的路径,可以是相对路径或绝对路径。
|
||||
backupCount (int, optional, default=0): 保存的备份文件数量,默认为 0,表示不保存备份文件。
|
||||
encoding (str, optional, default='utf-8'): 编码格式,默认为 'utf-8'。
|
||||
delay (bool, optional, default=False): 是否延迟写入,默认为 False,表示立即写入。
|
||||
utc (bool, optional, default=False): 是否使用 UTC 时区,默认为 False,表示不使用 UTC 时区。
|
||||
kwargs (dict, optional): 其他参数将被传递给 BaseRotatingHandler 类的 init 方法。
|
||||
|
||||
Raises:
|
||||
TypeError: If filename is not a string.
|
||||
ValueError: If backupCount is less than 0.
|
||||
TypeError: 如果 filename 不是 str 类型。
|
||||
ValueError: 如果 backupCount 小于等于 0。
|
||||
"""
|
||||
self.backup_count = backupCount
|
||||
self.utc = utc
|
||||
@@ -115,23 +103,16 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
|
||||
BaseRotatingHandler.__init__(self, filename, "a", encoding, delay)
|
||||
|
||||
def shouldRollover(self, record):
    """Report whether the log file should be rotated.

    Rotation is required as soon as the date-stamped filename computed
    for "now" no longer matches the file currently being written.

    Args:
        record (LogRecord): The record about to be emitted (unused here;
            rollover depends only on the date, not the record).

    Returns:
        bool: True if the handler should roll over to a new file.
    """
    return self.current_filename != self._compute_fn()
|
||||
|
||||
def doRollover(self):
|
||||
"""Perform the actual rollover operation.
|
||||
|
||||
Closes current file, creates new log file with current date suffix,
|
||||
and deletes any expired log files.
|
||||
"""
|
||||
scroll log
|
||||
"""
|
||||
if self.stream:
|
||||
self.stream.close()
|
||||
@@ -147,21 +128,15 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
|
||||
self.delete_expired_files()
|
||||
|
||||
def _compute_fn(self):
|
||||
"""Compute the current log filename with date suffix.
|
||||
|
||||
Returns:
|
||||
str: Filename with current date suffix (format: filename.YYYY-MM-DD)
|
||||
"""
|
||||
Calculate the log file name corresponding current time
|
||||
"""
|
||||
return self.base_filename + "." + time.strftime(
|
||||
self.suffix, time.localtime())
|
||||
|
||||
def _open(self):
|
||||
"""Open the current log file.
|
||||
|
||||
Also creates a symlink from the base filename to the current log file.
|
||||
|
||||
Returns:
|
||||
file object: The opened log file
|
||||
"""
|
||||
open new log file
|
||||
"""
|
||||
if self.encoding is None:
|
||||
stream = open(str(self.current_log_path), self.mode)
|
||||
@@ -184,10 +159,8 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
|
||||
return stream
|
||||
|
||||
def delete_expired_files(self):
|
||||
"""Delete expired log files based on backup count.
|
||||
|
||||
Only keeps the most recent backupCount files and deletes older ones.
|
||||
Does nothing if backupCount is <= 0.
|
||||
"""
|
||||
delete expired log files
|
||||
"""
|
||||
if self.backup_count <= 0:
|
||||
return
|
||||
@@ -215,21 +188,13 @@ def get_logger(name,
|
||||
file_name,
|
||||
without_formater=False,
|
||||
print_to_console=False):
|
||||
"""Create and configure a logger instance.
|
||||
|
||||
Args:
|
||||
name (str): Logger name
|
||||
file_name (str): Log file name (without path)
|
||||
without_formater (bool, optional): Skip adding formatter. Defaults to False.
|
||||
print_to_console (bool, optional): Also log to console. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Logger: Configured logger instance
|
||||
"""
|
||||
log_dir = os.getenv("FD_LOG_DIR", default="log")
|
||||
get logger
|
||||
"""
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
is_debug = int(os.getenv("FD_DEBUG", default="0"))
|
||||
is_debug = int(envs.FD_DEBUG)
|
||||
logger = logging.getLogger(name)
|
||||
if is_debug:
|
||||
logger.setLevel(level=logging.DEBUG)
|
||||
@@ -240,7 +205,7 @@ def get_logger(name,
|
||||
logger.removeHandler(handler)
|
||||
|
||||
LOG_FILE = "{0}/{1}".format(log_dir, file_name)
|
||||
backup_count = int(os.getenv("FD_LOG_BACKUP_COUNT", "7"))
|
||||
backup_count = int(envs.FD_LOG_BACKUP_COUNT)
|
||||
handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count)
|
||||
formatter = ColoredFormatter(
|
||||
"%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
|
||||
@@ -259,16 +224,8 @@ def get_logger(name,
|
||||
|
||||
|
||||
def str_to_datetime(date_string):
|
||||
"""Convert string to datetime object.
|
||||
|
||||
Supports both formats with and without microseconds.
|
||||
|
||||
Args:
|
||||
date_string (str): Date string in format "YYYY-MM-DD HH:MM:SS" or
|
||||
"YYYY-MM-DD HH:MM:SS.microseconds"
|
||||
|
||||
Returns:
|
||||
datetime: Parsed datetime object
|
||||
"""
|
||||
string to datetime class object
|
||||
"""
|
||||
if "." in date_string:
|
||||
return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
|
||||
@@ -277,14 +234,15 @@ def str_to_datetime(date_string):
|
||||
|
||||
|
||||
def datetime_diff(datetime_start, datetime_end):
|
||||
"""Calculate time difference between two datetime points.
|
||||
|
||||
"""
|
||||
Calculate the difference between two dates and times(s)
|
||||
|
||||
Args:
|
||||
datetime_start (Union[str, datetime.datetime]): Start time
|
||||
datetime_end (Union[str, datetime.datetime]): End time
|
||||
|
||||
datetime_start (Union[str, datetime.datetime]): start time
|
||||
datetime_end (Union[str, datetime.datetime]): end time
|
||||
|
||||
Returns:
|
||||
float: Time difference in seconds (always positive)
|
||||
float: date time difference(s)
|
||||
"""
|
||||
if isinstance(datetime_start, str):
|
||||
datetime_start = str_to_datetime(datetime_start)
|
||||
@@ -298,18 +256,7 @@ def datetime_diff(datetime_start, datetime_end):
|
||||
|
||||
|
||||
def download_file(url, save_path):
|
||||
"""Download a file from URL with progress bar.
|
||||
|
||||
Args:
|
||||
url (str): File URL to download
|
||||
save_path (str): Local path to save the file
|
||||
|
||||
Returns:
|
||||
bool: True if download succeeded
|
||||
|
||||
Raises:
|
||||
RuntimeError: If download fails (file is deleted on failure)
|
||||
"""
|
||||
"""Download file with progress bar"""
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
@@ -335,15 +282,7 @@ def download_file(url, save_path):
|
||||
|
||||
|
||||
def extract_tar(tar_path, output_dir):
|
||||
"""Extract contents of a tar file with progress tracking.
|
||||
|
||||
Args:
|
||||
tar_path (str): Path to tar file
|
||||
output_dir (str): Directory to extract files to
|
||||
|
||||
Raises:
|
||||
RuntimeError: If extraction fails
|
||||
"""
|
||||
"""Extract tar file with progress tracking"""
|
||||
try:
|
||||
with tarfile.open(tar_path) as tar:
|
||||
members = tar.getmembers()
|
||||
@@ -357,19 +296,19 @@ def extract_tar(tar_path, output_dir):
|
||||
|
||||
|
||||
def download_model(url, output_dir, temp_tar):
|
||||
"""Download and extract a model from URL.
|
||||
|
||||
"""
|
||||
下载模型,并将其解压到指定目录。
|
||||
|
||||
Args:
|
||||
url (str): Model file URL
|
||||
output_dir (str): Directory to save extracted model
|
||||
temp_tar (str): Temporary tar filename for download
|
||||
|
||||
url (str): 模型文件的URL地址。
|
||||
output_dir (str): 模型文件要保存的目录路径。
|
||||
temp_tar (str, optional): 临时保存模型文件的TAR包名称,默认为'temp.tar'.
|
||||
|
||||
Raises:
|
||||
Exception: If download or extraction fails
|
||||
RuntimeError: With link to model documentation if failure occurs
|
||||
|
||||
Note:
|
||||
Cleans up temporary files even if operation fails
|
||||
Exception: 如果下载或解压过程中出现任何错误,都会抛出Exception异常。
|
||||
|
||||
Returns:
|
||||
None - 无返回值,只是在下载和解压过程中进行日志输出和清理临时文件。
|
||||
"""
|
||||
try:
|
||||
temp_tar = os.path.join(output_dir, temp_tar)
|
||||
@@ -395,10 +334,8 @@ def download_model(url, output_dir, temp_tar):
|
||||
|
||||
|
||||
class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
"""Extended ArgumentParser that supports loading parameters from YAML files.
|
||||
|
||||
Supports nested configuration structures in YAML that get flattened
|
||||
into command-line style arguments.
|
||||
"""
|
||||
扩展 argparse.ArgumentParser,支持从 YAML 文件加载参数。
|
||||
"""
|
||||
|
||||
def __init__(self, *args, config_arg='--config', sep='_', **kwargs):
|
||||
@@ -411,18 +348,6 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
help='Path to YAML config file')
|
||||
|
||||
def parse_args(self, args=None, namespace=None):
|
||||
"""Parse arguments with support for YAML configuration files.
|
||||
|
||||
Args:
|
||||
args: Argument strings to parse (default: sys.argv[1:])
|
||||
namespace: Namespace object to store attributes (default: new Namespace)
|
||||
|
||||
Returns:
|
||||
Namespace: populated namespace object
|
||||
|
||||
Note:
|
||||
Command line arguments override values from config file
|
||||
"""
|
||||
# 使用临时解析器解析出 --config 参数
|
||||
tmp_ns, remaining_args = self.tmp_parser.parse_known_args(args=args)
|
||||
config_path = tmp_ns.config
|
||||
@@ -455,14 +380,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
return super().parse_args(args=remaining_args, namespace=namespace)
|
||||
|
||||
def _flatten_dict(self, d):
|
||||
"""Flatten nested dictionary into single level with joined keys.
|
||||
|
||||
Args:
|
||||
d (dict): Nested dictionary to flatten
|
||||
|
||||
Returns:
|
||||
dict: Flattened dictionary with keys joined by separator
|
||||
"""
|
||||
"""将嵌套字典展平为单层字典,键由分隔符连接"""
|
||||
|
||||
def _flatten(d, parent_key=''):
|
||||
items = []
|
||||
@@ -478,34 +396,14 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||
|
||||
|
||||
def resolve_obj_from_strname(strname: str):
    """Import and return the object named by a full dotted path.

    Args:
        strname (str): Dotted path such as ``"module.submodule.attr"``.

    Returns:
        object: The attribute resolved from the imported module.

    Example:
        >>> resolve_obj_from_strname("os.path.join")  # doctest: +ELLIPSIS
        <function join at 0x...>
    """
    # Split only on the last dot: everything before it is the module path.
    mod_path, attr_name = strname.rsplit(".", 1)
    target_module = importlib.import_module(mod_path)
    return getattr(target_module, attr_name)
|
||||
|
||||
|
||||
def check_unified_ckpt(model_dir):
|
||||
"""Check if directory contains a PaddleNLP unified checkpoint.
|
||||
|
||||
Args:
|
||||
model_dir (str): Path to model directory
|
||||
|
||||
Returns:
|
||||
bool: True if valid unified checkpoint, False otherwise
|
||||
|
||||
Raises:
|
||||
Exception: If checkpoint appears corrupted
|
||||
"""
|
||||
Check if the model is a PaddleNLP unified checkpoint
|
||||
"""
|
||||
model_files = list()
|
||||
all_files = os.listdir(model_dir)
|
||||
@@ -538,24 +436,16 @@ def check_unified_ckpt(model_dir):
|
||||
|
||||
|
||||
def get_host_ip():
    """Resolve and return this host's IP address.

    Returns:
        str: The IP address that the local hostname resolves to.
    """
    return socket.gethostbyname(socket.gethostname())
|
||||
|
||||
|
||||
def is_port_available(host, port):
|
||||
"""Check if a network port is available for binding.
|
||||
|
||||
Args:
|
||||
host (str): Hostname or IP address
|
||||
port (int): Port number
|
||||
|
||||
Returns:
|
||||
bool: True if port is available, False if already in use
|
||||
"""
|
||||
Check the port is available
|
||||
"""
|
||||
import errno
|
||||
import socket
|
||||
@@ -570,7 +460,112 @@ def is_port_available(host, port):
|
||||
return True
|
||||
|
||||
|
||||
def singleton(cls):
    """Class decorator that makes ``cls`` produce one shared instance.

    The first call constructs the instance with whatever arguments are
    supplied; every later call returns that same instance and ignores
    its arguments.
    """
    _cache = {}

    def _get_instance(*args, **kwargs):
        # Construct lazily on first use, then always hand back the cached one.
        if cls not in _cache:
            _cache[cls] = cls(*args, **kwargs)
        return _cache[cls]

    return _get_instance
|
||||
|
||||
|
||||
def print_gpu_memory_use(gpu_id: int, title: str) -> None:
    """Print total/used/free memory of one GPU device via NVML.

    Args:
        gpu_id (int): Index of the GPU device to query.
        title (str): Heading printed above the memory figures.
    """
    # Imported lazily so pynvml is only required when this helper is called.
    import pynvml
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
    pynvml.nvmlShutdown()

    # NVML reports raw byte counts.
    print(
        f"\n{title}:",
        f"\n\tDevice Total memory: {meminfo.total}",
        f"\n\tDevice Used memory: {meminfo.used}",
        f"\n\tDevice Free memory: {meminfo.free}",
    )
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
    """Perform ceiling division of two integers.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The smallest integer >= x / y.
    """
    # Adding (y - 1) before floor division rounds the quotient up.
    rounded_numerator = x + y - 1
    return rounded_numerator // y
|
||||
|
||||
|
||||
def none_or_str(value):
    """Map the literal string "None" to ``None``; pass everything else through.

    Useful as an argparse ``type=`` so that ``--flag None`` yields a real
    ``None`` instead of the four-character string.
    """
    if value == "None":
        return None
    return value
|
||||
|
||||
|
||||
def retrive_model_from_server(model_name_or_path, revision="master"):
    """Resolve a model path, downloading from AIStudio when needed.

    If ``model_name_or_path`` already exists on disk it is returned
    unchanged; otherwise it is treated as an AIStudio repo id and the
    snapshot is downloaded to a local cache directory.

    Args:
        model_name_or_path (str): Local directory or AIStudio repo id.
        revision (str, optional): Repo revision to download. Defaults to "master".

    Returns:
        str: Local directory containing the model files.

    Raises:
        Exception: If the model is neither a local path nor a downloadable
            repo; the original download error is chained as the cause.
    """
    if os.path.exists(model_name_or_path):
        return model_name_or_path
    try:
        repo_id = model_name_or_path
        # "baidu/..." models are hosted under the "PaddlePaddle" org on AIStudio.
        if repo_id.lower().strip().startswith("baidu"):
            repo_id = "PaddlePaddle" + repo_id.strip()[5:]
        local_path = envs.FD_MODEL_CACHE
        if local_path is None:
            local_path = f'{os.getenv("HOME")}/{repo_id}'
        snapshot_download(repo_id=repo_id,
                          revision=revision,
                          local_dir=local_path)
        model_name_or_path = local_path
    except Exception as e:
        # Chain the original error so download failures stay diagnosable.
        raise Exception(
            f"The setting model_name_or_path:{model_name_or_path} does not exist."
        ) from e
    return model_name_or_path
|
||||
|
||||
|
||||
def is_list_of(
    value: object,
    typ: Union[type[T], tuple[type[T], ...]],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
    """Check whether ``value`` is a list of elements of type ``typ``.

    Args:
        value: The value to check.
        typ: The type, or tuple of types, to check against.
        check: "first" inspects only the first element (cheap heuristic);
            "all" verifies every element.

    Returns:
        Whether the value is a list of the specified type.
    """
    # Anything that is not a list fails immediately.
    if not isinstance(value, list):
        return False

    if check == "all":
        return all(isinstance(item, typ) for item in value)
    if check == "first":
        # An empty list vacuously satisfies the "first element" check.
        return not value or isinstance(value[0], typ)

    assert_never(check)
|
||||
|
||||
|
||||
# Module-level logger singletons, one per subsystem, each writing to its own
# daily-rotating file under the configured log directory.
llm_logger = get_logger("fastdeploy", "fastdeploy.log")
data_processor_logger = get_logger("data_processor", "data_processor.log")
scheduler_logger = get_logger("scheduler", "scheduler.log")
api_server_logger = get_logger("api_server", "api_server.log")
# The console logger additionally mirrors records to the terminal.
console_logger = get_logger("console", "console.log", print_to_console=True)
spec_logger = get_logger("speculate", "speculate.log")
|
||||
|
||||
Reference in New Issue
Block a user