diff --git a/custom_ops/xpu_ops/setup_ops.py b/custom_ops/xpu_ops/setup_ops.py index fa31f5a8e..d29e6290f 100755 --- a/custom_ops/xpu_ops/setup_ops.py +++ b/custom_ops/xpu_ops/setup_ops.py @@ -21,6 +21,7 @@ Build and setup XPU custom ops for ERNIE Bot. import os import shutil import subprocess +import tarfile from pathlib import Path import paddle @@ -30,6 +31,39 @@ current_file = Path(__file__).resolve() base_dir = os.path.join(current_file.parent, "src") +def download_and_extract(url, destination_directory): + """ + Download a .tar.gz file using wget to the destination directory + and extract its contents without renaming the downloaded file. + + :param url: The URL of the .tar.gz file to download. + :param destination_directory: The directory where the file should be downloaded and extracted. + """ + os.makedirs(destination_directory, exist_ok=True) + + filename = os.path.basename(url) + file_path = os.path.join(destination_directory, filename) + + try: + print(f"Downloading: {url}") + subprocess.run( + ["wget", "-O", file_path, url], + check=True, + ) + print(f"Downloaded to: {file_path}") + + print(f"Extracting: {file_path}") + with tarfile.open(file_path, "r:gz") as tar: + tar.extractall(path=destination_directory) + print(f"Extracted to: {destination_directory}") + os.remove(file_path) + print(f"Deleted downloaded file: {file_path}") + except subprocess.CalledProcessError as e: + print(f"Error downloading file: {e}") + except Exception as e: + print(f"Error extracting file: {e}") + + def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR): """ build xpu plugin @@ -118,11 +152,50 @@ def xpu_setup_ops(): XDNN_PATH = os.getenv("XDNN_PATH") if XDNN_PATH is None: - XDNN_INC_PATH = os.path.join(PADDLE_INCLUDE_PATH) - XDNN_LIB_DIR = os.path.join(PADDLE_LIB_PATH) - else: - XDNN_INC_PATH = os.path.join(XDNN_PATH, "include") - XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so") + version_cmd = 'python -c "import paddle; print(paddle.version.xpu_xhpc())"' + try: + XHPC_VERSION = subprocess.check_output(version_cmd, shell=True).strip().decode() + print(f"Fetched XHPC_VERSION from paddle: {XHPC_VERSION}") + except Exception as e: + raise Exception(f"PaddlePaddle-xpu not installed, please install it first. {e}") + + XHPC_URL = f"https://klx-sdk-release-public.su.bcebos.com/xhpc/{XHPC_VERSION}/xhpc-ubuntu2004_x86_64.tar.gz" + THIRD_PARTY_PATH = os.path.join(current_file.parent, "third_party") + XHPC_PATH = os.path.join(THIRD_PARTY_PATH, "xhpc-ubuntu2004_x86_64") + + if os.path.exists(XHPC_PATH): + with open(os.path.join(XHPC_PATH, "version.txt")) as f: + date_line = [line.strip() for line in f.readlines() if "Date:" in line][0] + LOCAL_VERSION = f"dev/{date_line.split()[1]}" + if LOCAL_VERSION == XHPC_VERSION: + print("Local XHPC exists, skip downloading it again.") + else: + XHPC_UPDATE_POLICY_ENV = os.getenv("XHPC_UPDATE_POLICY") + if XHPC_UPDATE_POLICY_ENV is not None: + if XHPC_UPDATE_POLICY_ENV == "FORCE": + print("Forced update detected, downloading new XHPC.") + download_and_extract(XHPC_URL, THIRD_PARTY_PATH) + elif XHPC_UPDATE_POLICY_ENV == "SKIP": + print("Skipped updating XHPC.") + else: + raise Exception( + f"\033[91mInvalid value for environment variable XHPC_UPDATE_POLICY\033[0m: {XHPC_UPDATE_POLICY_ENV}, " + f"Valid environment values are FORCE or SKIP.", + ) + else: + raise Exception( + f"\033[91mLocal XHPC version mismatch\033[0m, expected {XHPC_VERSION}, found {LOCAL_VERSION} in {XHPC_PATH}. " + f"\nPlease set environment XHPC_UPDATE_POLICY and rebuild FastDeploy. " + f"\nexport XHPC_UPDATE_POLICY=FORCE for downloading version({XHPC_VERSION}) with force. " + f"\nexport XHPC_UPDATE_POLICY=SKIP for using local version({LOCAL_VERSION}).", + ) + else: + download_and_extract(XHPC_URL, THIRD_PARTY_PATH) + + XDNN_PATH = os.path.join(XHPC_PATH, "xdnn") + + XDNN_INC_PATH = os.path.join(XDNN_PATH, "include") + XDNN_LIB_DIR = os.path.join(XDNN_PATH, "so") XFA_PATH = os.getenv("XFA_PATH") if XFA_PATH is None: