Files
FastDeploy/custom_ops/setup_ops.py
Jiang-Jia-Jun 05c670e593 [Sync] Update to latest code (#2679)
* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-03 15:43:53 +08:00

513 lines
18 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" setup for FastDeploy custom ops """
import importlib
import json
import os
import shutil
import subprocess
import sys
import tarfile
from pathlib import Path
import paddle
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
from setuptools import find_namespace_packages, find_packages
def load_module_from_path(module_name, path):
"""
load python module from path
"""
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
ROOT_DIR = Path(__file__).parent.parent
# cannot import envs directly because it depends on fastdeploy,
# which is not installed yet
envs = load_module_from_path('envs',
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
def download_and_extract(url, destination_directory):
"""
Download a .tar.gz file using wget to the destination directory
and extract its contents without renaming the downloaded file.
:param url: The URL of the .tar.gz file to download.
:param destination_directory: The directory where the file should be downloaded and extracted.
"""
os.makedirs(destination_directory, exist_ok=True)
filename = os.path.basename(url)
file_path = os.path.join(destination_directory, filename)
try:
subprocess.run(
["wget", "-O", file_path, url],
check=True,
)
print(f"Downloaded: {file_path}")
with tarfile.open(file_path, "r:gz") as tar:
tar.extractall(path=destination_directory)
print(f"Extracted: {file_path} to {destination_directory}")
os.remove(file_path)
print(f"Deleted downloaded file: {file_path}")
except subprocess.CalledProcessError as e:
print(f"Error downloading file: {e}")
except Exception as e:
print(f"Error extracting file: {e}")
def clone_git_repo(version, repo_url, destination_path):
"""
Clone git repo to destination path.
"""
try:
subprocess.run(
[
"git",
"clone",
"-b",
version,
"--single-branch",
repo_url,
destination_path,
],
check=True,
)
return True
except subprocess.CalledProcessError:
return False
def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
"""
reset git repo to destination commit and apply patch.
"""
if commit_id is not None:
reset_cmd = ["git", "reset", "--hard", commit_id]
if patch is not None:
patch_source = os.path.join(cur_path, patch)
patch_destination = os.path.join(dst_path, patch)
shutil.copy(patch_source, patch_destination)
apply_cmd = ["git", "apply", patch]
try:
os.chdir(dst_path)
if commit_id is not None:
subprocess.run(reset_cmd, check=True)
if patch is not None:
subprocess.run(apply_cmd, check=True)
os.chdir(cur_path)
return True
except subprocess.CalledProcessError:
return False
def get_sm_version(archs):
"""
Get sm version of paddle.
"""
arch_set = set(archs)
try:
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
arch_set.add(cc)
except ValueError:
pass
return list(arch_set)
def get_nvcc_version():
"""
Get cuda version of nvcc.
"""
nvcc_output = subprocess.check_output(["nvcc", "--version"],
universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = float(output[release_idx].split(",")[0])
return nvcc_cuda_version
def get_gencode_flags(archs):
"""
Get gencode flags for current device or input.
"""
cc_s = get_sm_version(archs)
flags = []
for cc in cc_s:
if cc == 90:
cc = f"{cc}a"
flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
else:
flags += ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
return flags
def find_end_files(directory, end_str):
"""
Find files with end str in directory.
"""
gen_files = []
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith(end_str):
gen_files.append(os.path.join(root, file))
return gen_files
if paddle.is_compiled_with_rocm():
# NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
# so we need to check if paddle compiled with rocm at first.
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=[
"gpu_ops/save_with_output.cc",
"gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags.cu",
"gpu_ops/ngram_mask.cu",
"gpu_ops/gather_idx.cu",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/token_penalty_only_once.cu",
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/stop_generation_multi_stop_seqs.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/fused_get_rope.cu",
"gpu_ops/transfer_output.cc",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/update_inputs_beam.cu",
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/step.cu",
"gpu_ops/step_reschedule.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/enforce_generation.cu",
"gpu_ops/tune_cublaslt_gemm.cu",
],
extra_compile_args={
"cxx": ["-O3"],
"hipcc": [
"-O3",
"--gpu-max-threads-per-block=1024",
"-U__HIP_NO_HALF_OPERATORS__",
"-U__HIP_NO_HALF_CONVERSIONS__",
"-U__HIP_NO_BFLOAT16_OPERATORS__",
"-U__HIP_NO_BFLOAT16_CONVERSIONS__",
"-U__HIP_NO_BFLOAT162_OPERATORS__",
"-U__HIP_NO_BFLOAT162_CONVERSIONS__",
],
},
),
)
elif paddle.is_compiled_with_cuda():
sources = [
"gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags.cu",
"gpu_ops/ngram_mask.cu",
"gpu_ops/gather_idx.cu",
"gpu_ops/get_output_ep.cc",
"gpu_ops/get_mm_split_fuse.cc",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/token_penalty_only_once.cu",
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/stop_generation_multi_stop_seqs.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/step.cu",
"gpu_ops/step_reschedule.cu",
"gpu_ops/fused_get_rope.cu",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/update_inputs_beam.cu",
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/enforce_generation.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/tune_cublaslt_gemm.cu",
"gpu_ops/swap_cache_batch.cu",
"gpu_ops/swap_cache.cu",
"gpu_ops/step_system_cache.cu",
"gpu_ops/cpp_extensions.cc",
"gpu_ops/share_external_data.cu",
"gpu_ops/per_token_quant_fp8.cu",
"gpu_ops/extract_text_token_output.cu",
"gpu_ops/update_split_fuse_input.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
]
# pd_disaggregation
sources += [
"gpu_ops/remote_cache_kv_ipc.cc",
"gpu_ops/open_shm_and_get_meta_signal.cc",
"gpu_ops/init_signal_layerwise.cc",
"gpu_ops/get_data_ptr_ipc.cu",
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
]
cutlass_dir = "third_party/cutlass"
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
if not os.path.exists(cutlass_dir):
os.makedirs(cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git",
cutlass_dir)
if not os.listdir(cutlass_dir):
raise ValueError("Git clone cutlass failed!")
# deep gemm
deep_gemm_dir = "third_party/DeepGEMM"
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
if not os.path.exists(deep_gemm_dir):
os.makedirs(deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git",
deep_gemm_dir)
if not os.listdir(deep_gemm_dir):
raise ValueError("Git clone DeepGEMM failed!")
cur_path = os.path.dirname(os.path.abspath(__file__))
dst_path = os.path.join(cur_path, deep_gemm_dir)
commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
patch = "0001-DeepGEMM-95e81b3.patch"
process_git_repo(cur_path, dst_path, commit_id, patch)
dg_third_party_include_dirs = (
"third_party/cutlass/include/cute",
"third_party/cutlass/include/cutlass",
)
dg_include_dir = "third_party/DeepGEMM/deep_gemm/include"
os.makedirs(dg_include_dir, exist_ok=True)
for d in dg_third_party_include_dirs:
dirname = d.split("/")[-1]
src_dir = d
dst_dir = os.path.join(dg_include_dir, dirname)
# Remove existing directory if it exists
if os.path.exists(dst_dir):
if os.path.islink(dst_dir):
os.unlink(dst_dir)
else:
shutil.rmtree(dst_dir)
print(f"Copying {src_dir} to {dst_dir}")
# Copy the directory
try:
shutil.copytree(src_dir, dst_dir)
except Exception as e:
raise RuntimeError(
f"Failed to copy from {src_dir} to {dst_dir}: {e}")
json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git",
json_dir)
if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!")
nvcc_compile_args = get_gencode_flags(archs)
nvcc_compile_args += ["-DPADDLE_DEV"]
nvcc_compile_args += ["-DPADDLE_ON_INFERENCE"]
nvcc_compile_args += ["-DPy_LIMITED_API=0x03090000"]
nvcc_compile_args += [
"-Igpu_ops/cutlass_kernels",
"-Ithird_party/cutlass/include",
"-Ithird_party/cutlass/tools/util/include",
"-Igpu_ops/fp8_gemm_with_cutlass",
"-Igpu_ops",
"-Ithird_party/nlohmann_json/include",
]
nvcc_version = get_nvcc_version()
print(f'nvcc_version = {nvcc_version}')
if nvcc_version >= 12.0:
sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
cc = max(get_sm_version(archs))
print(f"cc = {cc}")
fp8_auto_gen_directory = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
if os.path.isdir(fp8_auto_gen_directory):
shutil.rmtree(fp8_auto_gen_directory)
if cc >= 75:
nvcc_compile_args += [
"-DENABLE_SCALED_MM_C2X=1",
"-Igpu_ops/cutlass_kernels/w8a8",
]
sources += [
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu",
"gpu_ops/quantization/common.cu",
]
if cc >= 80:
# append_attention
sources += ["gpu_ops/append_attention.cu"]
sources += find_end_files("gpu_ops/append_attn", ".cu")
# mla
sources += ["gpu_ops/multi_head_latent_attention.cu"]
# gemm_dequant
sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
# speculate_decoding
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
nvcc_compile_args += ["-DENABLE_BF16"]
# moe
sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu")
sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu")
sources += find_end_files("gpu_ops/moe/", ".cu")
nvcc_compile_args += ["-Igpu_ops/moe"]
if cc >= 89:
# Running generate fp8 gemm codes.
nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += [
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
]
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
if cc < 90:
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else:
nvcc_compile_args += [
"-gencode",
"arch=compute_90a,code=compute_90a",
"-O3",
"-DNDEBUG",
]
os.system(
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
os.system(
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1",
]
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
"gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
]
sources += [
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
"gpu_ops/fp8_gemm_with_cutlass/per_channel_fp8_fp8_half_gemm.cu",
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu",
"gpu_ops/scaled_gemm_f8_i4_f16_gemm.cu",
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
]
sources += find_end_files(fp8_auto_gen_directory, ".cu")
if cc >= 90 and nvcc_version >= 12.0:
# Hopper optmized mla
sources += find_end_files("gpu_ops/mla_attn", ".cu")
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args={"nvcc": nvcc_compile_args},
libraries=["cublasLt"],
),
packages=find_packages(where="third_party/DeepGEMM"),
package_dir={"": "third_party/DeepGEMM"},
package_data={
"deep_gemm": [
"include/deep_gemm/**/*",
"include/cute/**/*",
"include/cutlass/**/*",
]
},
include_package_data=True,
)
elif paddle.is_compiled_with_xpu():
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
else:
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
# cc flags
paddle_extra_compile_args = [
"-std=c++17",
"-shared",
"-fPIC",
"-Wno-parentheses",
"-DPADDLE_WITH_CUSTOM_KERNEL",
"-DPADDLE_ON_INFERENCE",
"-Wall",
"-O3",
"-g",
"-lstdc++fs",
"-D_GLIBCXX_USE_CXX11_ABI=1",
"-DPy_LIMITED_API=0x03090000",
]
setup(
name="fastdeploy_cpu_ops",
ext_modules=CppExtension(
sources=[
"cpu_ops/simd_sort.cc",
"cpu_ops/set_value_by_flags.cc",
"cpu_ops/token_penalty_multi_scores.cc",
"cpu_ops/stop_generation_multi_ends.cc",
"cpu_ops/update_inputs.cc",
"cpu_ops/get_padding_offset.cc",
],
extra_link_args=[
"-Wl,-rpath,$ORIGIN/x86-simd-sort/builddir",
"-Wl,-rpath,$ORIGIN/xFasterTransformer/build",
],
extra_compile_args=paddle_extra_compile_args,
),
packages=find_namespace_packages(where="third_party"),
package_dir={"": "third_party"},
include_package_data=True,
)