mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
【FastDeploy CLI】collect-env subcommand (#4044)
* collect-env subcommand * trigger ci --------- Co-authored-by: K11OntheBoat <your_email@example.com>
This commit is contained in:
@@ -34,7 +34,7 @@ After FastDeploy is launched, it supports continuous monitoring of the FastDeplo
|
||||
| `fastdeploy:available_gpu_block_num` | Gauge | Number of available gpu blocks in cache, including prefix caching blocks that are not officially released | Count |
|
||||
| `fastdeploy:free_gpu_block_num` | Gauge | Number of free blocks in cache | Count |
|
||||
| `fastdeploy:max_gpu_block_num` | Gauge | Number of total blocks determined when service started| Count |
|
||||
| `available_gpu_resource` | Gauge | Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num | Count |
|
||||
| `fastdeploy:available_gpu_resource` | Gauge | Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num | Count |
|
||||
| `fastdeploy:requests_number` | Counter | Total number of requests received | Count |
|
||||
| `fastdeploy:send_cache_failed_num` | Counter | Total number of failures of sending cache | Count |
|
||||
| `fastdeploy:first_token_latency` | Gauge | Latest time to generate first token in seconds | Seconds |
|
||||
|
@@ -34,7 +34,7 @@
|
||||
| `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的GPU块数量(包含尚未正式释放的前缀缓存块)| 个 |
|
||||
| `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 |
|
||||
| `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 |
|
||||
| `available_gpu_resource` | Gauge | 可用块占比,即可用GPU块数量 / 最大GPU块数量| 个 |
|
||||
| `fastdeploy:available_gpu_resource` | Gauge | 可用块占比,即可用GPU块数量 / 最大GPU块数量| 个 |
|
||||
| `fastdeploy:requests_number` | Counter | 已接收的请求总数 | 个 |
|
||||
| `fastdeploy:send_cache_failed_num` | Counter | 发送缓存失败的总次数 | 个 |
|
||||
| `fastdeploy:first_token_latency` | Gauge | 最近一次生成首token耗时 | 秒 |
|
||||
|
783
fastdeploy/collect_env.py
Normal file
783
fastdeploy/collect_env.py
Normal file
@@ -0,0 +1,783 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
# This file is modified from https://github.com/vllm-project/vllm/collect_env.py
|
||||
|
||||
import datetime
|
||||
import locale
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# Unlike the rest of the PyTorch this file must be python2 compliant.
|
||||
# This script outputs relevant system environment info
|
||||
# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
|
||||
from collections import namedtuple
|
||||
|
||||
import regex as re
|
||||
|
||||
from fastdeploy.envs import environment_variables
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import paddle
|
||||
|
||||
PADDLE_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
PADDLE_AVAILABLE = False
|
||||
|
||||
# System Environment Information
|
||||
SystemEnv = namedtuple(
|
||||
"SystemEnv",
|
||||
[
|
||||
"torch_version",
|
||||
"is_debug_build",
|
||||
"cuda_compiled_version",
|
||||
"paddle_version",
|
||||
"cuda_compiled_version_paddle",
|
||||
"gcc_version",
|
||||
"clang_version",
|
||||
"cmake_version",
|
||||
"os",
|
||||
"libc_version",
|
||||
"python_version",
|
||||
"python_platform",
|
||||
"is_cuda_available",
|
||||
"cuda_runtime_version",
|
||||
"cuda_module_loading",
|
||||
"nvidia_driver_version",
|
||||
"nvidia_gpu_models",
|
||||
"cudnn_version",
|
||||
"pip_version", # 'pip' or 'pip3'
|
||||
"pip_packages",
|
||||
"conda_packages",
|
||||
"is_xnnpack_available",
|
||||
"cpu_info",
|
||||
"fastdeploy_version", # fastdploy specific field
|
||||
"fastdeploy_build_flags", # fastdploy specific field
|
||||
"gpu_topo", # fastdploy specific field
|
||||
"env_vars",
|
||||
],
|
||||
)
|
||||
|
||||
DEFAULT_CONDA_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"cudatoolkit",
|
||||
"soumith",
|
||||
"mkl",
|
||||
"magma",
|
||||
"triton",
|
||||
"optree",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"mypy",
|
||||
"flake8",
|
||||
"triton",
|
||||
"optree",
|
||||
"onnx",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
}
|
||||
|
||||
|
||||
def run(command):
|
||||
"""Return (return-code, stdout, stderr)."""
|
||||
shell = True if type(command) is str else False
|
||||
try:
|
||||
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell)
|
||||
raw_output, raw_err = p.communicate()
|
||||
rc = p.returncode
|
||||
if get_platform() == "win32":
|
||||
enc = "oem"
|
||||
else:
|
||||
enc = locale.getpreferredencoding()
|
||||
output = raw_output.decode(enc)
|
||||
if command == "nvidia-smi topo -m":
|
||||
# don't remove the leading whitespace of `nvidia-smi topo -m`
|
||||
# because they are meaningful
|
||||
output = output.rstrip()
|
||||
else:
|
||||
output = output.strip()
|
||||
err = raw_err.decode(enc)
|
||||
return rc, output, err.strip()
|
||||
|
||||
except FileNotFoundError:
|
||||
cmd_str = command if isinstance(command, str) else command[0]
|
||||
return 127, "", f"Command not found: {cmd_str}"
|
||||
|
||||
|
||||
def run_and_read_all(run_lambda, command):
|
||||
"""Run command using run_lambda; reads and returns entire output if rc is 0."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out
|
||||
|
||||
|
||||
def run_and_parse_first_match(run_lambda, command, regex):
|
||||
"""Run command using run_lambda, returns the first regex match if it exists."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
match = re.search(regex, out)
|
||||
if match is None:
|
||||
return None
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def run_and_return_first_line(run_lambda, command):
|
||||
"""Run command using run_lambda and returns first line if output is not empty."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out.split("\n")[0]
|
||||
|
||||
|
||||
def get_conda_packages(run_lambda, patterns=None):
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_CONDA_PATTERNS
|
||||
conda = os.environ.get("CONDA_EXE", "conda")
|
||||
out = run_and_read_all(run_lambda, [conda, "list"])
|
||||
if out is None:
|
||||
return out
|
||||
|
||||
return "\n".join(
|
||||
line for line in out.splitlines() if not line.startswith("#") and any(name in line for name in patterns)
|
||||
)
|
||||
|
||||
|
||||
def get_gcc_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
|
||||
|
||||
|
||||
def get_clang_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "clang --version", r"clang version (.*)")
|
||||
|
||||
|
||||
def get_cmake_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
|
||||
|
||||
|
||||
def get_nvidia_driver_version(run_lambda):
|
||||
if get_platform() == "darwin":
|
||||
cmd = "kextstat | grep -i cuda"
|
||||
return run_and_parse_first_match(run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]")
|
||||
smi = get_nvidia_smi()
|
||||
return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
|
||||
|
||||
|
||||
def get_gpu_info(run_lambda):
|
||||
if get_platform() == "darwin":
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
gcnArch = ""
|
||||
return torch.cuda.get_device_name(None) + gcnArch
|
||||
return None
|
||||
smi = get_nvidia_smi()
|
||||
uuid_regex = re.compile(r" \(UUID: .+?\)")
|
||||
rc, out, _ = run_lambda(smi + " -L")
|
||||
if rc != 0:
|
||||
return None
|
||||
# Anonymize GPUs by removing their UUID
|
||||
return re.sub(uuid_regex, "", out)
|
||||
|
||||
|
||||
def get_running_cuda_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
|
||||
|
||||
|
||||
def get_cudnn_version(run_lambda):
|
||||
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
|
||||
if get_platform() == "win32":
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
|
||||
where_cmd = os.path.join(system_root, "System32", "where")
|
||||
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
|
||||
elif get_platform() == "darwin":
|
||||
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
|
||||
# https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
|
||||
# https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
|
||||
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
|
||||
cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
|
||||
else:
|
||||
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
|
||||
rc, out, _ = run_lambda(cudnn_cmd)
|
||||
# find will return 1 if there are permission errors or if not found
|
||||
if len(out) == 0 or (rc != 1 and rc != 0):
|
||||
l = os.environ.get("CUDNN_LIBRARY")
|
||||
if l is not None and os.path.isfile(l):
|
||||
return os.path.realpath(l)
|
||||
return None
|
||||
files_set = set()
|
||||
for fn in out.split("\n"):
|
||||
fn = os.path.realpath(fn) # eliminate symbolic links
|
||||
if os.path.isfile(fn):
|
||||
files_set.add(fn)
|
||||
if not files_set:
|
||||
return None
|
||||
# Alphabetize the result because the order is non-deterministic otherwise
|
||||
files = sorted(files_set)
|
||||
if len(files) == 1:
|
||||
return files[0]
|
||||
result = "\n".join(files)
|
||||
return "Probably one of the following:\n{}".format(result)
|
||||
|
||||
|
||||
def get_nvidia_smi():
|
||||
# Note: nvidia-smi is currently available only on Windows and Linux
|
||||
smi = "nvidia-smi"
|
||||
if get_platform() == "win32":
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
|
||||
legacy_path = os.path.join(program_files_root, "NVIDIA Corporation", "NVSMI", smi)
|
||||
new_path = os.path.join(system_root, "System32", smi)
|
||||
smis = [new_path, legacy_path]
|
||||
for candidate_smi in smis:
|
||||
if os.path.exists(candidate_smi):
|
||||
smi = '"{}"'.format(candidate_smi)
|
||||
break
|
||||
return smi
|
||||
|
||||
|
||||
def get_fastdeploy_version():
|
||||
import pkg_resources
|
||||
|
||||
version = os.environ.get("FASTDEPLOY_VERSION")
|
||||
if version:
|
||||
return version
|
||||
|
||||
try:
|
||||
return pkg_resources.get_distribution("fastdeploy").version
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
result = subprocess.run(["pip", "show", "fastdeploy"], capture_output=True, text=True)
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.split("\n"):
|
||||
if line.startswith("Version:"):
|
||||
return line.split(":")[1].strip()
|
||||
except:
|
||||
pass
|
||||
|
||||
return "unknown (could not determine version)"
|
||||
|
||||
|
||||
def summarize_fastdeploy_build_flags():
|
||||
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
|
||||
return "CUDA Archs: {};".format(os.getenv("FD_BUILDING_ARCS", "[]"))
|
||||
|
||||
|
||||
def get_gpu_topo(run_lambda):
|
||||
output = None
|
||||
|
||||
if get_platform() == "linux":
|
||||
output = run_and_read_all(run_lambda, "nvidia-smi topo -m")
|
||||
if output is None:
|
||||
output = run_and_read_all(run_lambda, "rocm-smi --showtopo")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# example outputs of CPU infos
|
||||
# * linux
|
||||
# Architecture: x86_64
|
||||
# CPU op-mode(s): 32-bit, 64-bit
|
||||
# Address sizes: 46 bits physical, 48 bits virtual
|
||||
# Byte Order: Little Endian
|
||||
# CPU(s): 128
|
||||
# On-line CPU(s) list: 0-127
|
||||
# Vendor ID: GenuineIntel
|
||||
# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# CPU family: 6
|
||||
# Model: 106
|
||||
# Thread(s) per core: 2
|
||||
# Core(s) per socket: 32
|
||||
# Socket(s): 2
|
||||
# Stepping: 6
|
||||
# BogoMIPS: 5799.78
|
||||
# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
|
||||
# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
|
||||
# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
|
||||
# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
|
||||
# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
|
||||
# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
|
||||
# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
|
||||
# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
|
||||
# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
|
||||
# Virtualization features:
|
||||
# Hypervisor vendor: KVM
|
||||
# Virtualization type: full
|
||||
# Caches (sum of all):
|
||||
# L1d: 3 MiB (64 instances)
|
||||
# L1i: 2 MiB (64 instances)
|
||||
# L2: 80 MiB (64 instances)
|
||||
# L3: 108 MiB (2 instances)
|
||||
# NUMA:
|
||||
# NUMA node(s): 2
|
||||
# NUMA node0 CPU(s): 0-31,64-95
|
||||
# NUMA node1 CPU(s): 32-63,96-127
|
||||
# Vulnerabilities:
|
||||
# Itlb multihit: Not affected
|
||||
# L1tf: Not affected
|
||||
# Mds: Not affected
|
||||
# Meltdown: Not affected
|
||||
# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
|
||||
# Retbleed: Not affected
|
||||
# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
|
||||
# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
|
||||
# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
|
||||
# Srbds: Not affected
|
||||
# Tsx async abort: Not affected
|
||||
# * win32
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU0
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
#
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU1
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
|
||||
|
||||
def get_cpu_info(run_lambda):
|
||||
rc, out, err = 0, "", ""
|
||||
if get_platform() == "linux":
|
||||
rc, out, err = run_lambda("lscpu")
|
||||
elif get_platform() == "win32":
|
||||
rc, out, err = run_lambda(
|
||||
"wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
|
||||
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE"
|
||||
)
|
||||
elif get_platform() == "darwin":
|
||||
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
||||
cpu_info = "None"
|
||||
if rc == 0:
|
||||
cpu_info = out
|
||||
else:
|
||||
cpu_info = err
|
||||
return cpu_info
|
||||
|
||||
|
||||
def get_platform():
|
||||
if sys.platform.startswith("linux"):
|
||||
return "linux"
|
||||
elif sys.platform.startswith("win32"):
|
||||
return "win32"
|
||||
elif sys.platform.startswith("cygwin"):
|
||||
return "cygwin"
|
||||
elif sys.platform.startswith("darwin"):
|
||||
return "darwin"
|
||||
else:
|
||||
return sys.platform
|
||||
|
||||
|
||||
def get_mac_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
|
||||
|
||||
|
||||
def get_windows_version(run_lambda):
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic")
|
||||
findstr_cmd = os.path.join(system_root, "System32", "findstr")
|
||||
return run_and_read_all(run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd))
|
||||
|
||||
|
||||
def get_lsb_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "lsb_release -a", r"Description:\t(.*)")
|
||||
|
||||
|
||||
def check_release_file(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"')
|
||||
|
||||
|
||||
def get_os(run_lambda):
|
||||
from platform import machine
|
||||
|
||||
platform = get_platform()
|
||||
|
||||
if platform == "win32" or platform == "cygwin":
|
||||
return get_windows_version(run_lambda)
|
||||
|
||||
if platform == "darwin":
|
||||
version = get_mac_version(run_lambda)
|
||||
if version is None:
|
||||
return None
|
||||
return "macOS {} ({})".format(version, machine())
|
||||
|
||||
if platform == "linux":
|
||||
# Ubuntu/Debian based
|
||||
desc = get_lsb_version(run_lambda)
|
||||
if desc is not None:
|
||||
return "{} ({})".format(desc, machine())
|
||||
|
||||
# Try reading /etc/*-release
|
||||
desc = check_release_file(run_lambda)
|
||||
if desc is not None:
|
||||
return "{} ({})".format(desc, machine())
|
||||
|
||||
return "{} ({})".format(platform, machine())
|
||||
|
||||
# Unknown platform
|
||||
return platform
|
||||
|
||||
|
||||
def get_python_platform():
|
||||
import platform
|
||||
|
||||
return platform.platform()
|
||||
|
||||
|
||||
def get_libc_version():
|
||||
import platform
|
||||
|
||||
if get_platform() != "linux":
|
||||
return "N/A"
|
||||
return "-".join(platform.libc_ver())
|
||||
|
||||
|
||||
def get_pip_packages(run_lambda, patterns=None):
|
||||
"""Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_PIP_PATTERNS
|
||||
|
||||
def run_with_pip():
|
||||
try:
|
||||
import importlib.util
|
||||
|
||||
pip_spec = importlib.util.find_spec("pip")
|
||||
pip_available = pip_spec is not None
|
||||
except ImportError:
|
||||
pip_available = False
|
||||
|
||||
if pip_available:
|
||||
cmd = [sys.executable, "-mpip", "list", "--format=freeze"]
|
||||
elif os.environ.get("UV") is not None:
|
||||
print("uv is set")
|
||||
cmd = ["uv", "pip", "list", "--format=freeze"]
|
||||
else:
|
||||
raise RuntimeError("Could not collect pip list output (pip or uv module not available)")
|
||||
|
||||
out = run_and_read_all(run_lambda, cmd)
|
||||
return "\n".join(line for line in out.splitlines() if any(name in line for name in patterns))
|
||||
|
||||
pip_version = "pip3" if sys.version[0] == "3" else "pip"
|
||||
out = run_with_pip()
|
||||
return pip_version, out
|
||||
|
||||
|
||||
def get_cuda_module_loading_config():
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
torch.cuda.init()
|
||||
config = os.environ.get("CUDA_MODULE_LOADING", "")
|
||||
return config
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def is_xnnpack_available():
|
||||
if TORCH_AVAILABLE:
|
||||
try:
|
||||
import torch.backends.xnnpack
|
||||
|
||||
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
|
||||
except:
|
||||
return "N/A"
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def get_env_vars():
|
||||
env_vars = ""
|
||||
secret_terms = ("secret", "token", "api", "access", "password")
|
||||
report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", "OMP_", "MKL_", "NVIDIA")
|
||||
for k, v in os.environ.items():
|
||||
if any(term in k.lower() for term in secret_terms):
|
||||
continue
|
||||
if k in environment_variables:
|
||||
env_vars = env_vars + "{}={}".format(k, v) + "\n"
|
||||
if k.startswith(report_prefix):
|
||||
env_vars = env_vars + "{}={}".format(k, v) + "\n"
|
||||
|
||||
return env_vars
|
||||
|
||||
|
||||
def get_env_info():
|
||||
run_lambda = run
|
||||
pip_version, pip_list_output = get_pip_packages(run_lambda)
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
version_str = torch.__version__
|
||||
debug_mode_str = str(torch.version.debug)
|
||||
torch_cuda_available_str = str(torch.cuda.is_available())
|
||||
cuda_version_str = torch.version.cuda
|
||||
else:
|
||||
version_str = debug_mode_str = torch_cuda_available_str = cuda_version_str = "N/A"
|
||||
|
||||
if PADDLE_AVAILABLE:
|
||||
paddle_version_str = paddle.__version__
|
||||
paddle_cuda_available_str = str(torch.cuda.is_available())
|
||||
paddle_cuda_version_str = str(paddle.version.cuda())
|
||||
else:
|
||||
version_str = paddle_cuda_available_str = cuda_version_str = "N/A"
|
||||
|
||||
if torch_cuda_available_str == "True" or paddle_cuda_available_str == "True":
|
||||
cuda_available_str = "True"
|
||||
else:
|
||||
cuda_available_str = "False"
|
||||
|
||||
sys_version = sys.version.replace("\n", " ")
|
||||
|
||||
conda_packages = get_conda_packages(run_lambda)
|
||||
|
||||
fastdeploy_version = get_fastdeploy_version()
|
||||
fastdeploy_build_flags = summarize_fastdeploy_build_flags()
|
||||
gpu_topo = get_gpu_topo(run_lambda)
|
||||
|
||||
return SystemEnv(
|
||||
torch_version=version_str,
|
||||
is_debug_build=debug_mode_str,
|
||||
paddle_version=paddle_version_str,
|
||||
cuda_compiled_version_paddle=paddle_cuda_version_str,
|
||||
python_version="{} ({}-bit runtime)".format(sys_version, sys.maxsize.bit_length() + 1),
|
||||
python_platform=get_python_platform(),
|
||||
is_cuda_available=cuda_available_str,
|
||||
cuda_compiled_version=cuda_version_str,
|
||||
cuda_runtime_version=get_running_cuda_version(run_lambda),
|
||||
cuda_module_loading=get_cuda_module_loading_config(),
|
||||
nvidia_gpu_models=get_gpu_info(run_lambda),
|
||||
nvidia_driver_version=get_nvidia_driver_version(run_lambda),
|
||||
cudnn_version=get_cudnn_version(run_lambda),
|
||||
pip_version=pip_version,
|
||||
pip_packages=pip_list_output,
|
||||
conda_packages=conda_packages,
|
||||
os=get_os(run_lambda),
|
||||
libc_version=get_libc_version(),
|
||||
gcc_version=get_gcc_version(run_lambda),
|
||||
clang_version=get_clang_version(run_lambda),
|
||||
cmake_version=get_cmake_version(run_lambda),
|
||||
is_xnnpack_available=is_xnnpack_available(),
|
||||
cpu_info=get_cpu_info(run_lambda),
|
||||
fastdeploy_version=fastdeploy_version,
|
||||
fastdeploy_build_flags=fastdeploy_build_flags,
|
||||
gpu_topo=gpu_topo,
|
||||
env_vars=get_env_vars(),
|
||||
)
|
||||
|
||||
|
||||
env_info_fmt = """
|
||||
==============================
|
||||
System Info
|
||||
==============================
|
||||
OS : {os}
|
||||
GCC version : {gcc_version}
|
||||
Clang version : {clang_version}
|
||||
CMake version : {cmake_version}
|
||||
Libc version : {libc_version}
|
||||
|
||||
==============================
|
||||
PyTorch Info
|
||||
==============================
|
||||
PyTorch version : {torch_version}
|
||||
Is debug build : {is_debug_build}
|
||||
CUDA used to build PyTorch : {cuda_compiled_version}
|
||||
|
||||
==============================
|
||||
Paddle Info
|
||||
==============================
|
||||
Paddle version : {paddle_version}
|
||||
CUDA used to build paddle : {cuda_compiled_version_paddle}
|
||||
|
||||
==============================
|
||||
Python Environment
|
||||
==============================
|
||||
Python version : {python_version}
|
||||
Python platform : {python_platform}
|
||||
|
||||
==============================
|
||||
CUDA / GPU Info
|
||||
==============================
|
||||
Is CUDA available : {is_cuda_available}
|
||||
CUDA runtime version : {cuda_runtime_version}
|
||||
CUDA_MODULE_LOADING set to : {cuda_module_loading}
|
||||
GPU models and configuration : {nvidia_gpu_models}
|
||||
Nvidia driver version : {nvidia_driver_version}
|
||||
cuDNN version : {cudnn_version}
|
||||
Is XNNPACK available : {is_xnnpack_available}
|
||||
|
||||
==============================
|
||||
CPU Info
|
||||
==============================
|
||||
{cpu_info}
|
||||
|
||||
==============================
|
||||
Versions of relevant libraries
|
||||
==============================
|
||||
{pip_packages}
|
||||
{conda_packages}
|
||||
""".strip()
|
||||
|
||||
# both the above code and the following code use `strip()` to
|
||||
# remove leading/trailing whitespaces, so we need to add a newline
|
||||
# in between to separate the two sections
|
||||
env_info_fmt += "\n\n"
|
||||
|
||||
env_info_fmt += """
|
||||
==============================
|
||||
FastDeploy Info
|
||||
==============================
|
||||
FastDeply Version : {fastdeploy_version}
|
||||
FastDeply Build Flags:
|
||||
{fastdeploy_build_flags}
|
||||
GPU Topology:
|
||||
{gpu_topo}
|
||||
|
||||
==============================
|
||||
Environment Variables
|
||||
==============================
|
||||
{env_vars}
|
||||
""".strip()
|
||||
|
||||
|
||||
def pretty_str(envinfo):
|
||||
|
||||
def replace_nones(dct, replacement="Could not collect"):
|
||||
for key in dct.keys():
|
||||
if dct[key] is not None:
|
||||
continue
|
||||
dct[key] = replacement
|
||||
return dct
|
||||
|
||||
def replace_bools(dct, true="Yes", false="No"):
|
||||
for key in dct.keys():
|
||||
if dct[key] is True:
|
||||
dct[key] = true
|
||||
elif dct[key] is False:
|
||||
dct[key] = false
|
||||
return dct
|
||||
|
||||
def prepend(text, tag="[prepend]"):
|
||||
lines = text.split("\n")
|
||||
updated_lines = [tag + line for line in lines]
|
||||
return "\n".join(updated_lines)
|
||||
|
||||
def replace_if_empty(text, replacement="No relevant packages"):
|
||||
if text is not None and len(text) == 0:
|
||||
return replacement
|
||||
return text
|
||||
|
||||
def maybe_start_on_next_line(string):
|
||||
# If `string` is multiline, prepend a \n to it.
|
||||
if string is not None and len(string.split("\n")) > 1:
|
||||
return "\n{}\n".format(string)
|
||||
return string
|
||||
|
||||
mutable_dict = envinfo._asdict()
|
||||
|
||||
# If nvidia_gpu_models is multiline, start on the next line
|
||||
mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(envinfo.nvidia_gpu_models)
|
||||
|
||||
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
|
||||
dynamic_cuda_fields = [
|
||||
"cuda_runtime_version",
|
||||
"nvidia_gpu_models",
|
||||
"nvidia_driver_version",
|
||||
]
|
||||
all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
|
||||
all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None for field in dynamic_cuda_fields)
|
||||
if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
|
||||
for field in all_cuda_fields:
|
||||
mutable_dict[field] = "No CUDA"
|
||||
if envinfo.cuda_compiled_version is None:
|
||||
mutable_dict["cuda_compiled_version"] = "None"
|
||||
|
||||
# Replace True with Yes, False with No
|
||||
mutable_dict = replace_bools(mutable_dict)
|
||||
|
||||
# Replace all None objects with 'Could not collect'
|
||||
mutable_dict = replace_nones(mutable_dict)
|
||||
|
||||
# If either of these are '', replace with 'No relevant packages'
|
||||
mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
|
||||
mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
|
||||
|
||||
# Tag conda and pip packages with a prefix
|
||||
# If they were previously None, they'll show up as ie '[conda] Could not collect'
|
||||
if mutable_dict["pip_packages"]:
|
||||
mutable_dict["pip_packages"] = prepend(mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version))
|
||||
if mutable_dict["conda_packages"]:
|
||||
mutable_dict["conda_packages"] = prepend(mutable_dict["conda_packages"], "[conda] ")
|
||||
mutable_dict["cpu_info"] = envinfo.cpu_info
|
||||
return env_info_fmt.format(**mutable_dict)
|
||||
|
||||
|
||||
def get_pretty_env_info():
|
||||
return pretty_str(get_env_info())
|
||||
|
||||
|
||||
def main():
|
||||
print("Collecting environment information...")
|
||||
output = get_pretty_env_info()
|
||||
print(output)
|
||||
|
||||
if TORCH_AVAILABLE and hasattr(torch, "utils") and hasattr(torch.utils, "_crash_handler"):
|
||||
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
|
||||
if sys.platform == "linux" and os.path.exists(minidump_dir):
|
||||
dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
|
||||
latest = max(dumps, key=os.path.getctime)
|
||||
ctime = os.path.getctime(latest)
|
||||
creation_time = datetime.datetime.fromtimestamp(ctime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
msg = (
|
||||
"\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time)
|
||||
+ "if this is related to your bug please include it when you file a report ***"
|
||||
)
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
51
fastdeploy/entrypoints/cli/collect_env.py
Normal file
51
fastdeploy/entrypoints/cli/collect_env.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
# This file is modified from https://github.com/vllm-project/vllm/entrypoints/cli/collect_env.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
|
||||
from fastdeploy.collect_env import main as collect_env_main
|
||||
from fastdeploy.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from fastdeploy.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class CollectEnvSubcommand(CLISubcommand):
|
||||
"""The `collect-env` subcommand for the FastDeploy CLI."""
|
||||
|
||||
name = "collect-env"
|
||||
|
||||
@staticmethod
|
||||
def cmd(args: argparse.Namespace) -> None:
|
||||
"""Collect information about the environment."""
|
||||
collect_env_main()
|
||||
|
||||
def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
||||
return subparsers.add_parser(
|
||||
"collect-env",
|
||||
help="Start collecting environment information.",
|
||||
description="Start collecting environment information.",
|
||||
usage="vllm collect-env",
|
||||
)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
|
||||
return [CollectEnvSubcommand()]
|
41
tests/entrypoints/cli/test_collect_env_conmmand.py
Normal file
41
tests/entrypoints/cli/test_collect_env_conmmand.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import unittest
|
||||
from argparse import Namespace, _SubParsersAction
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastdeploy.entrypoints.cli.collect_env import CollectEnvSubcommand, cmd_init
|
||||
|
||||
|
||||
class TestCollectEnvSubcommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.subcommand = CollectEnvSubcommand()
|
||||
|
||||
def test_name_property(self):
|
||||
self.assertEqual(self.subcommand.name, "collect-env")
|
||||
|
||||
@patch("fastdeploy.entrypoints.cli.collect_env.collect_env_main")
|
||||
def test_cmd(self, mock_collect_env_main):
|
||||
args = Namespace()
|
||||
self.subcommand.cmd(args)
|
||||
mock_collect_env_main.assert_called_once()
|
||||
|
||||
def test_subparser_init(self):
|
||||
mock_subparsers = MagicMock(spec=_SubParsersAction)
|
||||
parser = self.subcommand.subparser_init(mock_subparsers)
|
||||
print(parser)
|
||||
mock_subparsers.add_parser.assert_called_once_with(
|
||||
"collect-env",
|
||||
help="Start collecting environment information.",
|
||||
description="Start collecting environment information.",
|
||||
usage="vllm collect-env",
|
||||
)
|
||||
|
||||
|
||||
class TestCmdInit(unittest.TestCase):
|
||||
def test_cmd_init(self):
|
||||
subcommands = cmd_init()
|
||||
self.assertEqual(len(subcommands), 1)
|
||||
self.assertIsInstance(subcommands[0], CollectEnvSubcommand)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
488
tests/entrypoints/cli/test_collect_env_script.py
Normal file
488
tests/entrypoints/cli/test_collect_env_script.py
Normal file
@@ -0,0 +1,488 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import fastdeploy.collect_env as collect_env
|
||||
|
||||
|
||||
class TestCollectEnv(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.run_lambda = MagicMock()
|
||||
self.run_lambda.return_value = (0, "test output", "")
|
||||
|
||||
def test_run(self):
|
||||
result = collect_env.run("echo test")
|
||||
self.assertIsInstance(result, tuple)
|
||||
|
||||
def test_run_nvidia(self):
|
||||
result = collect_env.run("nvidia-smi topo -m")
|
||||
self.assertIsInstance(result, tuple)
|
||||
|
||||
def test_run_and_read_all(self):
|
||||
result = collect_env.run_and_read_all(self.run_lambda, "test command")
|
||||
self.assertEqual(result, "test output")
|
||||
self.run_lambda.return_value = (1, "version 1.0", "")
|
||||
result = collect_env.run_and_read_all(self.run_lambda, "test command")
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_run_and_parse_first_match(self):
|
||||
self.run_lambda.return_value = (0, "version 1.0", "")
|
||||
result = collect_env.run_and_parse_first_match(self.run_lambda, "test command", r"version (.*)")
|
||||
self.assertEqual(result, "1.0")
|
||||
self.run_lambda.return_value = (1, "version 1.0", "")
|
||||
result = collect_env.run_and_parse_first_match(self.run_lambda, "test command", r"version (.*)")
|
||||
self.assertEqual(result, None)
|
||||
self.run_lambda.return_value = (0, "version 1.0", "")
|
||||
result = collect_env.run_and_parse_first_match(self.run_lambda, "test command", r"sadsad")
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_run_and_return_first_line(self):
|
||||
self.run_lambda.return_value = (0, "line1\nline2", "")
|
||||
result = collect_env.run_and_return_first_line(self.run_lambda, "test command")
|
||||
self.assertEqual(result, "line1")
|
||||
self.run_lambda.return_value = (1, "line1\nline2", "")
|
||||
result = collect_env.run_and_return_first_line(self.run_lambda, "test command")
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_get_conda_packages(self):
|
||||
with patch("fastdeploy.collect_env.run_and_read_all") as mock_read:
|
||||
mock_read.return_value = "package1\npackage2"
|
||||
result = collect_env.get_conda_packages(self.run_lambda)
|
||||
self.assertIsNotNone(result)
|
||||
with patch("fastdeploy.collect_env.run_and_read_all") as mock_read:
|
||||
mock_read.return_value = None
|
||||
result = collect_env.get_conda_packages(self.run_lambda)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_gcc_version(self):
|
||||
with patch("fastdeploy.collect_env.run_and_parse_first_match") as mock_parse:
|
||||
mock_parse.return_value = "1.0"
|
||||
result = collect_env.get_gcc_version(self.run_lambda)
|
||||
self.assertEqual(result, "1.0")
|
||||
|
||||
def test_get_clang_version(self):
|
||||
with patch("fastdeploy.collect_env.run_and_parse_first_match") as mock_parse:
|
||||
mock_parse.return_value = "1.0"
|
||||
result = collect_env.get_clang_version(self.run_lambda)
|
||||
self.assertEqual(result, "1.0")
|
||||
|
||||
def test_get_cmake_version(self):
|
||||
with patch("fastdeploy.collect_env.run_and_parse_first_match") as mock_parse:
|
||||
mock_parse.return_value = "1.0"
|
||||
result = collect_env.get_cmake_version(self.run_lambda)
|
||||
self.assertEqual(result, "1.0")
|
||||
|
||||
def test_get_nvidia_driver_version(self):
|
||||
with patch("fastdeploy.collect_env.run_and_parse_first_match") as mock_parse:
|
||||
mock_parse.return_value = "1.0"
|
||||
result = collect_env.get_nvidia_driver_version(self.run_lambda)
|
||||
self.assertEqual(result, "1.0")
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run_and_parse_first_match", return_value="1.0"),
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="darwin"),
|
||||
):
|
||||
result = collect_env.get_nvidia_driver_version(self.run_lambda)
|
||||
self.assertEqual(result, "1.0")
|
||||
|
||||
def test_get_gpu_info(self):
|
||||
with patch("fastdeploy.collect_env.TORCH_AVAILABLE", False):
|
||||
result = collect_env.get_gpu_info(self.run_lambda)
|
||||
self.assertIsNotNone(result)
|
||||
with (
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="darwin"),
|
||||
patch("fastdeploy.collect_env.TORCH_AVAILABLE", True),
|
||||
patch("fastdeploy.collect_env.torch", create=True),
|
||||
):
|
||||
result = collect_env.get_gpu_info(self.run_lambda)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_get_running_cuda_version(self):
|
||||
with patch("fastdeploy.collect_env.run_and_parse_first_match") as mock_parse:
|
||||
mock_parse.return_value = "1.0"
|
||||
result = collect_env.get_running_cuda_version(self.run_lambda)
|
||||
self.assertEqual(result, "1.0")
|
||||
|
||||
def test_get_cudnn_version(self):
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run") as mock_run,
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="linux"),
|
||||
):
|
||||
mock_run.return_value = (0, "/usr/local/cuda/lib64/libcudnn.so.8.4.1", "")
|
||||
result = collect_env.get_cudnn_version(self.run_lambda)
|
||||
self.assertEqual(result, None)
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run") as mock_run,
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="win32"),
|
||||
):
|
||||
mock_run.return_value = (0, "/usr/local/cuda/lib64/libcudnn.so.8.4.1", "")
|
||||
result = collect_env.get_cudnn_version(self.run_lambda)
|
||||
self.assertEqual(result, None)
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run") as mock_run,
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="darwin"),
|
||||
):
|
||||
mock_run.return_value = (0, "/usr/local/cuda/lib64/libcudnn.so.8.4.1", "")
|
||||
result = collect_env.get_cudnn_version(self.run_lambda)
|
||||
self.assertEqual(result, None)
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run") as mock_run,
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="darwin"),
|
||||
):
|
||||
mock_run.return_value = (2, "/usr/local/cuda/lib64/libcudnn.so.8.4.1\n/usr/xxx", "")
|
||||
self.run_lambda.return_value = (2, "version 1.0", "")
|
||||
result = collect_env.get_cudnn_version(self.run_lambda)
|
||||
self.assertEqual(result, None)
|
||||
|
||||
with (
|
||||
patch("os.path.realpath", side_effect=lambda x: f"/real_path/to/{x.split('/')[-1]}"),
|
||||
patch("os.path.isfile", return_value=True),
|
||||
):
|
||||
self.run_lambda.return_value = (
|
||||
0,
|
||||
"/usr/local/cuda/lib/libcudnn.so.8\n/usr/local/cuda/lib/libcudnn.so.8.2.1",
|
||||
"",
|
||||
)
|
||||
cudnn_version = collect_env.get_cudnn_version(self.run_lambda)
|
||||
# 验证返回结果是预期的多行字符串
|
||||
expected_output = (
|
||||
"Probably one of the following:\n/real_path/to/libcudnn.so.8\n/real_path/to/libcudnn.so.8.2.1"
|
||||
)
|
||||
self.assertEqual(cudnn_version, expected_output)
|
||||
|
||||
def test_get_nvidia_smi(self):
|
||||
result = collect_env.get_nvidia_smi()
|
||||
self.assertIsNotNone(result)
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="win32"):
|
||||
result = collect_env.get_nvidia_smi()
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_get_fastdeploy_version(self):
|
||||
with patch("fastdeploy.collect_env.os.environ.get", return_value="1.0"):
|
||||
result = collect_env.get_fastdeploy_version()
|
||||
self.assertEqual(result, "1.0")
|
||||
with patch("fastdeploy.collect_env.os.environ.get", return_value=None):
|
||||
result = collect_env.get_fastdeploy_version()
|
||||
self.assertIsNotNone(result)
|
||||
with patch("pkg_resources.get_distribution", side_effect=Exception("Package not found")):
|
||||
with patch("fastdeploy.collect_env.os.environ.get", return_value=None):
|
||||
with patch("subprocess.run", return_value=None):
|
||||
result = collect_env.get_fastdeploy_version()
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_summarize_fastdeploy_build_flags(self):
|
||||
result = collect_env.summarize_fastdeploy_build_flags()
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_get_gpu_topo(self):
|
||||
result = collect_env.get_gpu_topo(self.run_lambda)
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_get_cpu_info(self):
|
||||
self.run_lambda.return_value = (0, "Architecture: x86_64\nModel name: Intel(R) Xeon(R) CPU", "")
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="linux"):
|
||||
cpu_info = collect_env.get_cpu_info(self.run_lambda)
|
||||
self.assertIn("x86_64", cpu_info)
|
||||
self.assertIn("Intel(R) Xeon(R)", cpu_info)
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="win32"):
|
||||
cpu_info = collect_env.get_cpu_info(self.run_lambda)
|
||||
self.assertIn("x86_64", cpu_info)
|
||||
self.assertIn("Intel(R) Xeon(R)", cpu_info)
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="darwin"):
|
||||
cpu_info = collect_env.get_cpu_info(self.run_lambda)
|
||||
self.assertIn("x86_64", cpu_info)
|
||||
self.assertIn("Intel(R) Xeon(R)", cpu_info)
|
||||
self.run_lambda.return_value = (1, "Architecture: x86_64\nModel name: Intel(R) Xeon(R) CPU", "err")
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="darwin"):
|
||||
cpu_info = collect_env.get_cpu_info(self.run_lambda)
|
||||
self.assertIn("err", cpu_info)
|
||||
|
||||
def test_get_platform(self):
|
||||
with patch("sys.platform", "linux"):
|
||||
self.assertEqual(collect_env.get_platform(), "linux")
|
||||
with patch("sys.platform", "win32"):
|
||||
self.assertEqual(collect_env.get_platform(), "win32")
|
||||
with patch("sys.platform", "cygwin"):
|
||||
self.assertEqual(collect_env.get_platform(), "cygwin")
|
||||
with patch("sys.platform", "darwin"):
|
||||
self.assertEqual(collect_env.get_platform(), "darwin")
|
||||
|
||||
def test_get_os_linux_lsb_success(self):
|
||||
"""测试 Linux 环境下,lsb_release 命令成功。"""
|
||||
with patch("sys.platform", "linux"):
|
||||
# 模拟 get_lsb_version 成功返回
|
||||
with patch("fastdeploy.collect_env.get_lsb_version", return_value="Ubuntu 20.04 LTS"):
|
||||
# 模拟 platform.machine 成功返回
|
||||
with patch("platform.machine", return_value="x86_64"):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertEqual(result, "Ubuntu 20.04 LTS (x86_64)")
|
||||
|
||||
def test_get_os_linux_lsb_fail_check_release_success(self):
|
||||
"""测试 Linux 环境下,lsb_release 失败,但 /etc/*-release 成功。"""
|
||||
with patch("sys.platform", "linux"):
|
||||
# 模拟 get_lsb_version 失败
|
||||
with patch("fastdeploy.collect_env.get_lsb_version", return_value=None):
|
||||
# 模拟 check_release_file 成功返回
|
||||
with patch("fastdeploy.collect_env.check_release_file", return_value="CentOS Linux 8"):
|
||||
with patch("platform.machine", return_value="x86_64"):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertEqual(result, "CentOS Linux 8 (x86_64)")
|
||||
|
||||
def test_get_os_linux_all_fail(self):
|
||||
"""测试 Linux 环境下,所有方法都失败。"""
|
||||
with patch("sys.platform", "linux"):
|
||||
with patch("fastdeploy.collect_env.get_lsb_version", return_value=None):
|
||||
with patch("fastdeploy.collect_env.check_release_file", return_value=None):
|
||||
with patch("platform.machine", return_value="x86_64"):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertEqual(result, "linux (x86_64)")
|
||||
|
||||
def test_get_os_windows_success(self):
|
||||
"""测试 Windows 环境下,命令成功。"""
|
||||
with patch("sys.platform", "win32"):
|
||||
# 模拟 get_windows_version 成功返回
|
||||
with patch("fastdeploy.collect_env.get_windows_version", return_value="Microsoft Windows 10"):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertEqual(result, "Microsoft Windows 10")
|
||||
|
||||
def test_get_os_windows_fail(self):
|
||||
"""测试 Windows 环境下,命令失败。"""
|
||||
with patch("sys.platform", "win32"):
|
||||
# 模拟 get_windows_version 失败返回 None
|
||||
with patch("fastdeploy.collect_env.get_windows_version", return_value=None):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_os_macos_success(self):
|
||||
"""测试 macOS 环境下,命令成功。"""
|
||||
with patch("sys.platform", "darwin"):
|
||||
# 模拟 get_mac_version 成功返回
|
||||
with patch("fastdeploy.collect_env.get_mac_version", return_value="12.3.1"):
|
||||
with patch("platform.machine", return_value="arm64"):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertEqual(result, "macOS 12.3.1 (arm64)")
|
||||
|
||||
def test_get_os_macos_fail(self):
|
||||
"""测试 macOS 环境下,命令失败。"""
|
||||
with patch("sys.platform", "darwin"):
|
||||
# 模拟 get_mac_version 失败返回 None
|
||||
with patch("fastdeploy.collect_env.get_mac_version", return_value=None):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_os_unknown_platform(self):
|
||||
"""测试未知平台的情况。"""
|
||||
with patch("sys.platform", "solaris"):
|
||||
result = collect_env.get_os(self.run_lambda)
|
||||
self.assertEqual(result, "solaris")
|
||||
|
||||
def test_get_python_platform(self):
|
||||
"""测试 get_python_platform 函数返回正确值。"""
|
||||
with patch("platform.platform", return_value="Linux-5.15.0-76-generic-x86_64-with-glibc2.35"):
|
||||
result = collect_env.get_python_platform()
|
||||
self.assertEqual(result, "Linux-5.15.0-76-generic-x86_64-with-glibc2.35")
|
||||
|
||||
def test_get_libc_version_linux_success(self):
|
||||
"""测试在 Linux 环境下成功获取 libc 版本。"""
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="linux"):
|
||||
with patch("platform.libc_ver", return_value=("glibc", "2.35")):
|
||||
result = collect_env.get_libc_version()
|
||||
self.assertEqual(result, "glibc-2.35")
|
||||
|
||||
def test_get_libc_version_non_linux(self):
|
||||
"""测试在非 Linux 环境下返回 'N/A'。"""
|
||||
with patch("fastdeploy.collect_env.get_platform", return_value="win32"):
|
||||
# 确保 platform.libc_ver() 不被调用,或者即使调用也不会影响结果
|
||||
with patch("platform.libc_ver") as mock_libc_ver:
|
||||
result = collect_env.get_libc_version()
|
||||
self.assertEqual(result, "N/A")
|
||||
mock_libc_ver.assert_not_called()
|
||||
|
||||
def test_get_pip_packages_no_pip_or_uv(self):
|
||||
"""Test that a RuntimeError is raised when neither pip nor uv are available."""
|
||||
with patch.dict("os.environ", {}, clear=True), patch("importlib.util.find_spec", return_value=None):
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
collect_env.get_pip_packages(self.run_lambda)
|
||||
self.assertIn("Could not collect pip list output", str(cm.exception))
|
||||
|
||||
def test_get_pip_packages_success(self):
|
||||
"""Test that the pip module is available and a list of packages is returned."""
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run_and_read_all") as mock_run_and_read_all,
|
||||
patch("importlib.util.find_spec", return_value=True),
|
||||
):
|
||||
|
||||
mock_run_and_read_all.return_value = "torch==2.0.0\nregex==2023.1.1\nnumpy==1.25.0"
|
||||
|
||||
pip_version, packages = collect_env.get_pip_packages(self.run_lambda)
|
||||
|
||||
self.assertEqual(pip_version, "pip3" if sys.version[0] == "3" else "pip")
|
||||
self.assertIn("torch==2.0.0", packages)
|
||||
self.assertIn("numpy==1.25.0", packages)
|
||||
self.assertNotIn("regex", packages)
|
||||
|
||||
def test_get_pip_packages_uv_available(self):
|
||||
"""Test that uv is used when pip is not available but the UV environment variable is set."""
|
||||
with (
|
||||
patch.dict("os.environ", {"UV": "1"}),
|
||||
patch("fastdeploy.collect_env.run_and_read_all") as mock_run_and_read_all,
|
||||
patch("importlib.util.find_spec", return_value=False),
|
||||
):
|
||||
|
||||
mock_run_and_read_all.return_value = "torch==2.0.0\nregex==2023.1.1\nnumpy==1.25.0"
|
||||
|
||||
pip_version, packages = collect_env.get_pip_packages(self.run_lambda)
|
||||
|
||||
self.assertIsNotNone(packages)
|
||||
|
||||
def test_get_pip_packages_command_fail(self):
|
||||
"""Test that an empty string is returned when the pip command fails."""
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run_and_read_all", return_value="\n"),
|
||||
patch("importlib.util.find_spec", return_value=True),
|
||||
):
|
||||
|
||||
pip_version, packages = collect_env.get_pip_packages(self.run_lambda)
|
||||
self.assertEqual(packages, "")
|
||||
|
||||
def test_get_pip_packages_custom_patterns(self):
|
||||
"""Test that the function correctly filters packages based on custom patterns."""
|
||||
with (
|
||||
patch("fastdeploy.collect_env.run_and_read_all") as mock_run_and_read_all,
|
||||
patch("importlib.util.find_spec", return_value=True),
|
||||
):
|
||||
|
||||
mock_run_and_read_all.return_value = "torch==2.0.0\nmy-custom-lib==1.0.0\nrequests==2.28.1"
|
||||
|
||||
custom_patterns = {"my-custom-lib", "requests"}
|
||||
pip_version, packages = collect_env.get_pip_packages(self.run_lambda, patterns=custom_patterns)
|
||||
|
||||
self.assertIn("my-custom-lib==1.0.0", packages)
|
||||
self.assertIn("requests==2.28.1", packages)
|
||||
self.assertNotIn("torch", packages)
|
||||
|
||||
def test_xnnpack_available_with_torch(self):
|
||||
with patch("fastdeploy.collect_env.TORCH_AVAILABLE", True), patch("fastdeploy.collect_env.torch", create=True):
|
||||
result = collect_env.is_xnnpack_available()
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
def test_xnnpack_not_available_without_torch(self):
|
||||
"""测试 torch 不可用时,返回 'N/A'。"""
|
||||
with patch("fastdeploy.collect_env.TORCH_AVAILABLE", False):
|
||||
result = collect_env.is_xnnpack_available()
|
||||
self.assertEqual(result, "N/A")
|
||||
|
||||
def test_get_env_vars_with_relevant_vars(self):
|
||||
"""测试正确收集相关环境变量。"""
|
||||
# 准备一个包含各种类型环境变量的字典
|
||||
mock_env = {
|
||||
"TORCH_DEBUG": "1",
|
||||
"CUDA_ARCHS": "7.5",
|
||||
"SOME_OTHER_VAR": "value",
|
||||
"MY_API_KEY": "secret_key",
|
||||
"FASTDEPLOY_MODEL_DIR": "/path/to/model", # 假设这个在 environment_variables 中
|
||||
}
|
||||
|
||||
# 模拟 environment_variables 列表
|
||||
with patch.dict(os.environ, mock_env, clear=True):
|
||||
with patch(
|
||||
"fastdeploy.collect_env.environment_variables",
|
||||
["FASTDEPLOY_MODEL_DIR"],
|
||||
):
|
||||
env_vars_string = collect_env.get_env_vars()
|
||||
self.assertIn("TORCH_DEBUG=1", env_vars_string)
|
||||
self.assertIn("CUDA_ARCHS=7.5", env_vars_string)
|
||||
self.assertIn("FASTDEPLOY_MODEL_DIR=/path/to/model", env_vars_string)
|
||||
self.assertNotIn("SOME_OTHER_VAR", env_vars_string)
|
||||
self.assertNotIn("MY_API_KEY", env_vars_string)
|
||||
|
||||
def test_get_cuda_config_with_both_vars_set(self):
|
||||
with patch("fastdeploy.collect_env.TORCH_AVAILABLE", True), patch("fastdeploy.collect_env.torch", create=True):
|
||||
mock_env = {
|
||||
"CUDA_MODULE_LOADING": "xxx",
|
||||
}
|
||||
with patch.dict(os.environ, mock_env, clear=True):
|
||||
result = collect_env.get_cuda_module_loading_config()
|
||||
self.assertEqual(result, "xxx")
|
||||
|
||||
def test_get_cuda_config_with_no_vars_set(self):
|
||||
"""测试两个环境变量都未设置。"""
|
||||
with patch("fastdeploy.collect_env.TORCH_AVAILABLE", False):
|
||||
result = collect_env.get_cuda_module_loading_config()
|
||||
self.assertEqual(result, "N/A")
|
||||
|
||||
def test_get_env_info_full(self):
|
||||
|
||||
# 使用嵌套的 with patch 语句来模拟所有依赖函数的返回值
|
||||
with (
|
||||
patch("fastdeploy.collect_env.get_platform", return_value="linux"),
|
||||
patch("fastdeploy.collect_env.get_os", return_value="Ubuntu 20.04 (x86_64)"),
|
||||
patch("fastdeploy.collect_env.get_python_platform", return_value="Python 3.8.10"),
|
||||
patch("fastdeploy.collect_env.get_cuda_module_loading_config", return_value="CUDA_DEVICE_VISIBLE=0"),
|
||||
patch("fastdeploy.collect_env.get_libc_version", return_value="glibc-2.31"),
|
||||
patch("fastdeploy.collect_env.get_gcc_version", return_value="9.3.0"),
|
||||
patch("fastdeploy.collect_env.get_clang_version", return_value=None),
|
||||
patch("fastdeploy.collect_env.get_cmake_version", return_value="3.18.4"),
|
||||
patch("fastdeploy.collect_env.get_nvidia_driver_version", return_value="470.82.00"),
|
||||
patch("fastdeploy.collect_env.get_cudnn_version", return_value="8.2.1"),
|
||||
patch("fastdeploy.collect_env.get_running_cuda_version", return_value="11.4"),
|
||||
patch("fastdeploy.collect_env.get_gpu_info", return_value="GeForce RTX 3080"),
|
||||
patch("fastdeploy.collect_env.get_cpu_info", return_value="Intel(R) Core(TM) i9-10900K"),
|
||||
patch("fastdeploy.collect_env.is_xnnpack_available", return_value="True"),
|
||||
patch("fastdeploy.collect_env.get_fastdeploy_version", return_value="1.0.0"),
|
||||
patch("fastdeploy.collect_env.get_conda_packages", return_value="numpy==1.22.0\ntorch==1.11.0"),
|
||||
patch("fastdeploy.collect_env.get_pip_packages", return_value=("pip3", "requests==2.27.1\nscipy==1.8.0")),
|
||||
patch("fastdeploy.collect_env.get_env_vars", return_value="CUDA_VISIBLE_DEVICES=0"),
|
||||
patch("fastdeploy.collect_env.get_gpu_topo", return_value="GPU-Direct disabled"),
|
||||
patch("fastdeploy.collect_env.torch", create=True, __version__="1.0.0"),
|
||||
):
|
||||
|
||||
info_string = collect_env.get_env_info()
|
||||
|
||||
self.assertIsNotNone(info_string)
|
||||
|
||||
def test_get_env_info_all_na(self):
|
||||
|
||||
with (
|
||||
patch("fastdeploy.collect_env.get_python_platform", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_cuda_module_loading_config", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_libc_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_gcc_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_clang_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_cmake_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_nvidia_driver_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_cudnn_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_running_cuda_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_gpu_info", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_cpu_info", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.is_xnnpack_available", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_fastdeploy_version", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.get_conda_packages", return_value=None),
|
||||
patch("fastdeploy.collect_env.get_pip_packages", return_value=("N/A", "N/A")),
|
||||
patch("fastdeploy.collect_env.get_env_vars", return_value=""),
|
||||
patch("fastdeploy.collect_env.get_gpu_topo", return_value="N/A"),
|
||||
patch("fastdeploy.collect_env.TORCH_AVAILABLE", return_value=False),
|
||||
patch("fastdeploy.collect_env.PADDLE_AVAILABLE", return_value=False),
|
||||
patch("fastdeploy.collect_env.torch", create=True, __version__="1.0.0"),
|
||||
):
|
||||
|
||||
info_string = collect_env.get_env_info()
|
||||
|
||||
self.assertIsNotNone(info_string)
|
||||
|
||||
def test_main_with_collect(self):
|
||||
captured_output = io.StringIO()
|
||||
with (
|
||||
patch("sys.stdout", new=captured_output),
|
||||
patch("fastdeploy.collect_env.torch", create=True, __version__="1.0.0"),
|
||||
):
|
||||
collect_env.main()
|
||||
output = captured_output.getvalue()
|
||||
expected_message = "Collecting environment information"
|
||||
self.assertIn(expected_message, output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user