mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] support XDNN downloading function (#5365)
This commit is contained in:
@@ -21,6 +21,7 @@ Build and setup XPU custom ops for ERNIE Bot.
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
||||
import paddle
|
||||
@@ -30,6 +31,39 @@ current_file = Path(__file__).resolve()
|
||||
base_dir = os.path.join(current_file.parent, "src")
|
||||
|
||||
|
||||
def download_and_extract(url, destination_directory):
|
||||
"""
|
||||
Download a .tar.gz file using wget to the destination directory
|
||||
and extract its contents without renaming the downloaded file.
|
||||
|
||||
:param url: The URL of the .tar.gz file to download.
|
||||
:param destination_directory: The directory where the file should be downloaded and extracted.
|
||||
"""
|
||||
os.makedirs(destination_directory, exist_ok=True)
|
||||
|
||||
filename = os.path.basename(url)
|
||||
file_path = os.path.join(destination_directory, filename)
|
||||
|
||||
try:
|
||||
print(f"Downloading: {url}")
|
||||
subprocess.run(
|
||||
["wget", "-O", file_path, url],
|
||||
check=True,
|
||||
)
|
||||
print(f"Downloaded to: {file_path}")
|
||||
|
||||
print(f"Extracting: {file_path}")
|
||||
with tarfile.open(file_path, "r:gz") as tar:
|
||||
tar.extractall(path=destination_directory)
|
||||
print(f"Extracted to: {destination_directory}")
|
||||
os.remove(file_path)
|
||||
print(f"Deleted downloaded file: {file_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error downloading file: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error extracting file: {e}")
|
||||
|
||||
|
||||
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
|
||||
"""
|
||||
build xpu plugin
|
||||
@@ -118,11 +152,50 @@ def xpu_setup_ops():
|
||||
|
||||
XDNN_PATH = os.getenv("XDNN_PATH")
|
||||
if XDNN_PATH is None:
|
||||
XDNN_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH)
|
||||
XDNN_LIB_DIR = os.path.join(PADDLE_LIB_PATH)
|
||||
else:
|
||||
XDNN_INC_PATH = os.path.join(XDNN_PATH, "include")
|
||||
XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so")
|
||||
version_cmd = 'python -c "import paddle; print(paddle.version.xpu_xhpc())"'
|
||||
try:
|
||||
XHPC_VERSION = subprocess.check_output(version_cmd, shell=True).strip().decode()
|
||||
print(f"Fetched XHPC_VERSION from paddle: {XHPC_VERSION}")
|
||||
except Exception as e:
|
||||
raise Exception(f"PaddlePaddle-xpu not installed, please install it first. {e}")
|
||||
|
||||
XHPC_URL = f"https://klx-sdk-release-public.su.bcebos.com/xhpc/{XHPC_VERSION}/xhpc-ubuntu2004_x86_64.tar.gz"
|
||||
THIRD_PARTY_PATH = os.path.join(current_file.parent, "third_party")
|
||||
XHPC_PATH = os.path.join(THIRD_PARTY_PATH, "xhpc-ubuntu2004_x86_64")
|
||||
|
||||
if os.path.exists(XHPC_PATH):
|
||||
with open(os.path.join(XHPC_PATH, "version.txt")) as f:
|
||||
date_line = [line.strip() for line in f.readlines() if "Date:" in line][0]
|
||||
LOCAL_VERSION = f"dev/{date_line.split()[1]}"
|
||||
if LOCAL_VERSION == XHPC_VERSION:
|
||||
print("Local XHPC exists, skip downloading it again.")
|
||||
else:
|
||||
XHPC_UPDATE_POLICY_ENV = os.getenv("XHPC_UPDATE_POLICY")
|
||||
if XHPC_UPDATE_POLICY_ENV is not None:
|
||||
if XHPC_UPDATE_POLICY_ENV == "FORCE":
|
||||
print("Forced update detected, downloading new XHPC.")
|
||||
download_and_extract(XHPC_URL, THIRD_PARTY_PATH)
|
||||
elif XHPC_UPDATE_POLICY_ENV == "SKIP":
|
||||
print("Skipped updating XHPC.")
|
||||
else:
|
||||
raise Exception(
|
||||
f"\033[91mInvalid value for environment variable XHPC_UPDATE_POLICY\033[0m: {XHPC_UPDATE_POLICY_ENV}, "
|
||||
f"Valid environment values are FORCE or SKIP.",
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"\033[91mLocal XHPC version mismatch\033[0m, expected {XHPC_VERSION}, found {LOCAL_VERSION} in {XHPC_PATH}. "
|
||||
f"\nPlease set environment XHPC_UPDATE_POLICY and rebuild FastDeploy. "
|
||||
f"\nexport XHPC_UPDATE_POLICY=FORCE for downloading version({XHPC_VERSION}) with force. "
|
||||
f"\nexport XHPC_UPDATE_POLICY=SKIP for using local version({LOCAL_VERSION}).",
|
||||
)
|
||||
else:
|
||||
download_and_extract(XHPC_URL, THIRD_PARTY_PATH)
|
||||
|
||||
XDNN_PATH = os.path.join(XHPC_PATH, "xdnn")
|
||||
|
||||
XDNN_INC_PATH = os.path.join(XDNN_PATH, "include")
|
||||
XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so")
|
||||
|
||||
XFA_PATH = os.getenv("XFA_PATH")
|
||||
if XFA_PATH is None:
|
||||
|
||||
Reference in New Issue
Block a user