[XPU] support XDNN downloading function (#5365)

This commit is contained in:
Lucas
2025-12-05 11:16:45 +08:00
committed by GitHub
parent dd2e9a14c7
commit 7b0b6e470a

View File

@@ -21,6 +21,7 @@ Build and setup XPU custom ops for ERNIE Bot.
import os
import shutil
import subprocess
import tarfile
from pathlib import Path
import paddle
@@ -30,6 +31,39 @@ current_file = Path(__file__).resolve()
base_dir = os.path.join(current_file.parent, "src")
def download_and_extract(url, destination_directory):
"""
Download a .tar.gz file using wget to the destination directory
and extract its contents without renaming the downloaded file.
:param url: The URL of the .tar.gz file to download.
:param destination_directory: The directory where the file should be downloaded and extracted.
"""
os.makedirs(destination_directory, exist_ok=True)
filename = os.path.basename(url)
file_path = os.path.join(destination_directory, filename)
try:
print(f"Downloading: {url}")
subprocess.run(
["wget", "-O", file_path, url],
check=True,
)
print(f"Downloaded to: {file_path}")
print(f"Extracting: {file_path}")
with tarfile.open(file_path, "r:gz") as tar:
tar.extractall(path=destination_directory)
print(f"Extracted to: {destination_directory}")
os.remove(file_path)
print(f"Deleted downloaded file: {file_path}")
except subprocess.CalledProcessError as e:
print(f"Error downloading file: {e}")
except Exception as e:
print(f"Error extracting file: {e}")
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
"""
build xpu plugin
@@ -118,11 +152,50 @@ def xpu_setup_ops():
XDNN_PATH = os.getenv("XDNN_PATH")
if XDNN_PATH is None:
XDNN_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH)
XDNN_LIB_DIR = os.path.join(PADDLE_LIB_PATH)
else:
XDNN_INC_PATH = os.path.join(XDNN_PATH, "include")
XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so")
version_cmd = 'python -c "import paddle; print(paddle.version.xpu_xhpc())"'
try:
XHPC_VERSION = subprocess.check_output(version_cmd, shell=True).strip().decode()
print(f"Fetched XHPC_VERSION from paddle: {XHPC_VERSION}")
except Exception as e:
raise Exception(f"PaddlePaddle-xpu not installed, please install it first. {e}")
XHPC_URL = f"https://klx-sdk-release-public.su.bcebos.com/xhpc/{XHPC_VERSION}/xhpc-ubuntu2004_x86_64.tar.gz"
THIRD_PARTY_PATH = os.path.join(current_file.parent, "third_party")
XHPC_PATH = os.path.join(THIRD_PARTY_PATH, "xhpc-ubuntu2004_x86_64")
if os.path.exists(XHPC_PATH):
with open(os.path.join(XHPC_PATH, "version.txt")) as f:
date_line = [line.strip() for line in f.readlines() if "Date:" in line][0]
LOCAL_VERSION = f"dev/{date_line.split()[1]}"
if LOCAL_VERSION == XHPC_VERSION:
print("Local XHPC exists, skip downloading it again.")
else:
XHPC_UPDATE_POLICY_ENV = os.getenv("XHPC_UPDATE_POLICY")
if XHPC_UPDATE_POLICY_ENV is not None:
if XHPC_UPDATE_POLICY_ENV == "FORCE":
print("Forced update detected, downloading new XHPC.")
download_and_extract(XHPC_URL, THIRD_PARTY_PATH)
elif XHPC_UPDATE_POLICY_ENV == "SKIP":
print("Skipped updating XHPC.")
else:
raise Exception(
f"\033[91mInvalid value for environment variable XHPC_UPDATE_POLICY\033[0m: {XHPC_UPDATE_POLICY_ENV}, "
f"Valid environment values are FORCE or SKIP.",
)
else:
raise Exception(
f"\033[91mLocal XHPC version mismatch\033[0m, expected {XHPC_VERSION}, found {LOCAL_VERSION} in {XHPC_PATH}. "
f"\nPlease set environment XHPC_UPDATE_POLICY and rebuild FastDeploy. "
f"\nexport XHPC_UPDATE_POLICY=FORCE for downloading version({XHPC_VERSION}) with force. "
f"\nexport XHPC_UPDATE_POLICY=SKIP for using local version({LOCAL_VERSION}).",
)
else:
download_and_extract(XHPC_URL, THIRD_PARTY_PATH)
XDNN_PATH = os.path.join(XHPC_PATH, "xdnn")
XDNN_INC_PATH = os.path.join(XDNN_PATH, "include")
XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so")
XFA_PATH = os.getenv("XFA_PATH")
if XFA_PATH is None: