Sync v2.0 version of code to GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -12,13 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Utility functions and classes for FastDeploy server operations.
This module provides:
- Custom logging handlers and formatters
- File download and extraction utilities
- Configuration parsing helpers
- Various helper functions for server operations
"""
import argparse
@@ -33,19 +26,21 @@ import time
from datetime import datetime
from logging.handlers import BaseRotatingHandler
from pathlib import Path
from typing import Literal, TypeVar, Union
import requests
import yaml
from aistudio_sdk.snapshot_download import snapshot_download
from tqdm import tqdm
from typing_extensions import TypeIs, assert_never
from fastdeploy import envs
T = TypeVar("T")
class EngineError(Exception):
"""Base exception class for engine-related errors.
Attributes:
message (str): Human-readable error description
error_code (int): HTTP-style error code (default: 400)
"""
"""Base exception class for engine errors"""
def __init__(self, message, error_code=400):
super().__init__(message)
@@ -53,13 +48,7 @@ class EngineError(Exception):
class ColoredFormatter(logging.Formatter):
"""Custom log formatter that adds color to console output.
Colors different log levels for better visibility:
- WARNING: Yellow
- ERROR: Red
- CRITICAL: Red
"""
"""自定义日志格式器,用于控制台输出带颜色"""
COLOR_CODES = {
logging.WARNING: 33, # 黄色
logging.ERROR: 31, # 红色
@@ -77,10 +66,8 @@ class ColoredFormatter(logging.Formatter):
class DailyRotatingFileHandler(BaseRotatingHandler):
"""Daily rotating file handler that supports multi-process logging.
Similar to `logging.TimedRotatingFileHandler` but designed to work safely
in multi-process environments.
"""
like `logging.TimedRotatingFileHandler`, but this class support multi-process
"""
def __init__(self,
@@ -90,19 +77,20 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
delay=False,
utc=False,
**kwargs):
"""Initialize the rotating file handler.
"""
初始化 RotatingFileHandler 对象。
Args:
filename (str): Path to the log file (can be relative or absolute)
backupCount (int, optional): Number of backup files to keep. Defaults to 0.
encoding (str, optional): File encoding. Defaults to "utf-8".
delay (bool, optional): Delay file opening until first write. Defaults to False.
utc (bool, optional): Use UTC timezone for rollover. Defaults to False.
**kwargs: Additional arguments passed to BaseRotatingHandler.
filename (str): 日志文件的路径,可以是相对路径或绝对路径。
backupCount (int, optional, default=0): 保存的备份文件数量,默认为 0表示不保存备份文件。
encoding (str, optional, default='utf-8'): 编码格式,默认为 'utf-8'
delay (bool, optional, default=False): 是否延迟写入,默认为 False表示立即写入。
utc (bool, optional, default=False): 是否使用 UTC 时区,默认为 False表示不使用 UTC 时区。
kwargs (dict, optional): 其他参数将被传递给 BaseRotatingHandler 类的 init 方法。
Raises:
TypeError: If filename is not a string.
ValueError: If backupCount is less than 0.
TypeError: 如果 filename 不是 str 类型。
ValueError: 如果 backupCount 小于等于 0
"""
self.backup_count = backupCount
self.utc = utc
@@ -115,23 +103,16 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
BaseRotatingHandler.__init__(self, filename, "a", encoding, delay)
def shouldRollover(self, record):
"""Determine if a rollover should occur.
Args:
record (LogRecord): The log record being processed
Returns:
bool: True if rollover should occur, False otherwise
"""
check scroll through the log
"""
if self.current_filename != self._compute_fn():
return True
return False
def doRollover(self):
"""Perform the actual rollover operation.
Closes current file, creates new log file with current date suffix,
and deletes any expired log files.
"""
scroll log
"""
if self.stream:
self.stream.close()
@@ -147,21 +128,15 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
self.delete_expired_files()
def _compute_fn(self):
"""Compute the current log filename with date suffix.
Returns:
str: Filename with current date suffix (format: filename.YYYY-MM-DD)
"""
Calculate the log file name corresponding current time
"""
return self.base_filename + "." + time.strftime(
self.suffix, time.localtime())
def _open(self):
"""Open the current log file.
Also creates a symlink from the base filename to the current log file.
Returns:
file object: The opened log file
"""
open new log file
"""
if self.encoding is None:
stream = open(str(self.current_log_path), self.mode)
@@ -184,10 +159,8 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
return stream
def delete_expired_files(self):
"""Delete expired log files based on backup count.
Only keeps the most recent backupCount files and deletes older ones.
Does nothing if backupCount is <= 0.
"""
delete expired log files
"""
if self.backup_count <= 0:
return
@@ -215,21 +188,13 @@ def get_logger(name,
file_name,
without_formater=False,
print_to_console=False):
"""Create and configure a logger instance.
Args:
name (str): Logger name
file_name (str): Log file name (without path)
without_formater (bool, optional): Skip adding formatter. Defaults to False.
print_to_console (bool, optional): Also log to console. Defaults to False.
Returns:
Logger: Configured logger instance
"""
log_dir = os.getenv("FD_LOG_DIR", default="log")
get logger
"""
log_dir = envs.FD_LOG_DIR
if not os.path.exists(log_dir):
os.mkdir(log_dir)
is_debug = int(os.getenv("FD_DEBUG", default="0"))
is_debug = int(envs.FD_DEBUG)
logger = logging.getLogger(name)
if is_debug:
logger.setLevel(level=logging.DEBUG)
@@ -240,7 +205,7 @@ def get_logger(name,
logger.removeHandler(handler)
LOG_FILE = "{0}/{1}".format(log_dir, file_name)
backup_count = int(os.getenv("FD_LOG_BACKUP_COUNT", "7"))
backup_count = int(envs.FD_LOG_BACKUP_COUNT)
handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count)
formatter = ColoredFormatter(
"%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
@@ -259,16 +224,8 @@ def get_logger(name,
def str_to_datetime(date_string):
"""Convert string to datetime object.
Supports both formats with and without microseconds.
Args:
date_string (str): Date string in format "YYYY-MM-DD HH:MM:SS" or
"YYYY-MM-DD HH:MM:SS.microseconds"
Returns:
datetime: Parsed datetime object
"""
string to datetime class object
"""
if "." in date_string:
return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
@@ -277,14 +234,15 @@ def str_to_datetime(date_string):
def datetime_diff(datetime_start, datetime_end):
"""Calculate time difference between two datetime points.
"""
Calculate the difference between two dates and times(s)
Args:
datetime_start (Union[str, datetime.datetime]): Start time
datetime_end (Union[str, datetime.datetime]): End time
datetime_start (Union[str, datetime.datetime]): start time
datetime_end (Union[str, datetime.datetime]): end time
Returns:
float: Time difference in seconds (always positive)
float: date time difference(s)
"""
if isinstance(datetime_start, str):
datetime_start = str_to_datetime(datetime_start)
@@ -298,18 +256,7 @@ def datetime_diff(datetime_start, datetime_end):
def download_file(url, save_path):
"""Download a file from URL with progress bar.
Args:
url (str): File URL to download
save_path (str): Local path to save the file
Returns:
bool: True if download succeeded
Raises:
RuntimeError: If download fails (file is deleted on failure)
"""
"""Download file with progress bar"""
try:
response = requests.get(url, stream=True)
response.raise_for_status()
@@ -335,15 +282,7 @@ def download_file(url, save_path):
def extract_tar(tar_path, output_dir):
"""Extract contents of a tar file with progress tracking.
Args:
tar_path (str): Path to tar file
output_dir (str): Directory to extract files to
Raises:
RuntimeError: If extraction fails
"""
"""Extract tar file with progress tracking"""
try:
with tarfile.open(tar_path) as tar:
members = tar.getmembers()
@@ -357,19 +296,19 @@ def extract_tar(tar_path, output_dir):
def download_model(url, output_dir, temp_tar):
"""Download and extract a model from URL.
"""
下载模型,并将其解压到指定目录。
Args:
url (str): Model file URL
output_dir (str): Directory to save extracted model
temp_tar (str): Temporary tar filename for download
url (str): 模型文件的URL地址。
output_dir (str): 模型文件要保存的目录路径。
temp_tar (str, optional): 临时保存模型文件的TAR包名称默认为'temp.tar'.
Raises:
Exception: If download or extraction fails
RuntimeError: With link to model documentation if failure occurs
Note:
Cleans up temporary files even if operation fails
Exception: 如果下载或解压过程中出现任何错误都会抛出Exception异常。
Returns:
None - 无返回值,只是在下载和解压过程中进行日志输出和清理临时文件。
"""
try:
temp_tar = os.path.join(output_dir, temp_tar)
@@ -395,10 +334,8 @@ def download_model(url, output_dir, temp_tar):
class FlexibleArgumentParser(argparse.ArgumentParser):
"""Extended ArgumentParser that supports loading parameters from YAML files.
Supports nested configuration structures in YAML that get flattened
into command-line style arguments.
"""
扩展 argparse.ArgumentParser支持从 YAML 文件加载参数。
"""
def __init__(self, *args, config_arg='--config', sep='_', **kwargs):
@@ -411,18 +348,6 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
help='Path to YAML config file')
def parse_args(self, args=None, namespace=None):
"""Parse arguments with support for YAML configuration files.
Args:
args: Argument strings to parse (default: sys.argv[1:])
namespace: Namespace object to store attributes (default: new Namespace)
Returns:
Namespace: populated namespace object
Note:
Command line arguments override values from config file
"""
# 使用临时解析器解析出 --config 参数
tmp_ns, remaining_args = self.tmp_parser.parse_known_args(args=args)
config_path = tmp_ns.config
@@ -455,14 +380,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return super().parse_args(args=remaining_args, namespace=namespace)
def _flatten_dict(self, d):
"""Flatten nested dictionary into single level with joined keys.
Args:
d (dict): Nested dictionary to flatten
Returns:
dict: Flattened dictionary with keys joined by separator
"""
"""将嵌套字典展平为单层字典,键由分隔符连接"""
def _flatten(d, parent_key=''):
items = []
@@ -478,34 +396,14 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
def resolve_obj_from_strname(strname: str):
"""Import and return an object from its full dotted path string.
Args:
strname (str): Full dotted path to object (e.g. "module.submodule.Class")
Returns:
object: The imported object
Example:
>>> resolve_obj_from_strname("os.path.join")
<function join at 0x...>
"""
module_name, obj_name = strname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
def check_unified_ckpt(model_dir):
"""Check if directory contains a PaddleNLP unified checkpoint.
Args:
model_dir (str): Path to model directory
Returns:
bool: True if valid unified checkpoint, False otherwise
Raises:
Exception: If checkpoint appears corrupted
"""
Check if the model is a PaddleNLP unified checkpoint
"""
model_files = list()
all_files = os.listdir(model_dir)
@@ -538,24 +436,16 @@ def check_unified_ckpt(model_dir):
def get_host_ip():
"""Get host machine's IP address.
Returns:
str: Host IP address
"""
Get host IP address
"""
ip = socket.gethostbyname(socket.gethostname())
return ip
def is_port_available(host, port):
"""Check if a network port is available for binding.
Args:
host (str): Hostname or IP address
port (int): Port number
Returns:
bool: True if port is available, False if already in use
"""
Check the port is available
"""
import errno
import socket
@@ -570,7 +460,112 @@ def is_port_available(host, port):
return True
def singleton(cls):
    """Class decorator that turns *cls* into a lazily-created singleton.

    The first call constructs the instance with the supplied arguments;
    every subsequent call returns that same instance (later constructor
    arguments are ignored).

    Args:
        cls: The class to wrap.

    Returns:
        A factory callable that always returns the single shared instance.
    """
    import functools

    instances = {}

    # wraps() preserves the class's __name__/__doc__ on the factory so
    # introspection and logging still show the original class identity.
    @functools.wraps(cls, updated=[])
    def get_instance(*args, **kwargs):
        if cls not in instances:
            # Construct lazily on first use only.
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
def print_gpu_memory_use(gpu_id: int, title: str) -> None:
    """Print total/used/free memory (in bytes) of one GPU via NVML.

    Args:
        gpu_id (int): Zero-based NVML device index.
        title (str): Heading printed above the memory figures.
    """
    import pynvml
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
    finally:
        # Always release the NVML session, even if the queries above fail.
        pynvml.nvmlShutdown()
    print(
        f"\n{title}:",
        f"\n\tDevice Total memory: {meminfo.total}",
        f"\n\tDevice Used memory: {meminfo.used}",
        f"\n\tDevice Free memory: {meminfo.free}",
    )
def ceil_div(x: int, y: int) -> int:
    """Divide two integers, rounding the quotient up.

    Args:
        x: The dividend.
        y: The divisor.

    Returns:
        The smallest integer greater than or equal to x / y
        (for a positive divisor).
    """
    # Adding (y - 1) to the dividend before floor division rounds up.
    adjusted = x + y - 1
    return adjusted // y
def none_or_str(value):
    """Map the literal string "None" to a real ``None`` object.

    Any other value is returned unchanged. Useful e.g. as an argparse
    ``type=`` converter so "None" on the command line becomes ``None``.
    """
    if value == "None":
        return None
    return value
def retrive_model_from_server(model_name_or_path, revision="master"):
    """Resolve a model path, downloading it from AIStudio when needed.

    If ``model_name_or_path`` is an existing local path it is returned
    unchanged; otherwise it is treated as an AIStudio repo id whose
    snapshot is downloaded to a local cache directory.

    Args:
        model_name_or_path (str): Local directory or AIStudio repo id.
        revision (str, optional): Repo revision to download. Defaults
            to "master".

    Returns:
        str: Local directory containing the model files.

    Raises:
        Exception: If the model cannot be resolved or downloaded; the
            original download error is attached as ``__cause__``.
    """
    if os.path.exists(model_name_or_path):
        return model_name_or_path
    try:
        repo_id = model_name_or_path
        # Repos published under "baidu" are hosted under "PaddlePaddle".
        if repo_id.lower().strip().startswith("baidu"):
            repo_id = "PaddlePaddle" + repo_id.strip()[5:]
        local_path = envs.FD_MODEL_CACHE
        if local_path is None:
            local_path = f'{os.getenv("HOME")}/{repo_id}'
        snapshot_download(repo_id=repo_id,
                          revision=revision,
                          local_dir=local_path)
        model_name_or_path = local_path
    except Exception as e:
        # Chain the underlying error so the real download failure is
        # not hidden from the caller.
        raise Exception(
            f"The setting model_name_or_path:{model_name_or_path} is not exist."
        ) from e
    return model_name_or_path
def is_list_of(
    value: object,
    typ: Union[type[T], tuple[type[T], ...]],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
    """Check whether *value* is a list whose elements have type *typ*.

    Args:
        value: The object to inspect.
        typ: A type, or tuple of types, to test the elements against.
        check: "first" tests only the first element (empty lists pass);
            "all" tests every element.

    Returns:
        Whether *value* is a list of the specified element type.
    """
    if not isinstance(value, list):
        return False

    if check == "all":
        return all(isinstance(item, typ) for item in value)
    if check == "first":
        # An empty list vacuously satisfies the element-type check.
        return not value or isinstance(value[0], typ)
    assert_never(check)
# Module-level loggers shared across FastDeploy components. Each writes to
# its own daily-rotating file (created by get_logger under FD_LOG_DIR).
llm_logger = get_logger("fastdeploy", "fastdeploy.log")
data_processor_logger = get_logger("data_processor", "data_processor.log")
scheduler_logger = get_logger("scheduler", "scheduler.log")
api_server_logger = get_logger("api_server", "api_server.log")
# The console logger additionally echoes records to stdout/stderr.
console_logger = get_logger("console", "console.log", print_to_console=True)
spec_logger = get_logger("speculate", "speculate.log")