# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""setup for FastDeploy custom ops"""

import importlib.util
import json
import os
import shutil
import subprocess
import sys
import tarfile
from pathlib import Path

import paddle
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
from setuptools import find_namespace_packages, find_packages

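# Typical usage (illustrative; the exact entry point and arguments depend on the
# repository's build scripts): this file is driven by paddle's cpp_extension
# machinery, e.g. something like `python setup_ops.py install` run from the
# custom ops directory.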
def load_module_from_path(module_name, path):
    """
    load python module from path
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

def update_git_repo():
    try:
        print("update third party repo...", flush=True)
        original_dir = os.getcwd()
        submodule_dir = os.path.dirname(os.path.abspath(__file__))
        third_party_path = os.path.join(submodule_dir, "third_party")
        root_path = Path(third_party_path)

        # check if third_party is empty
        update_third_party = False
        for dirpath in root_path.iterdir():
            if dirpath.is_dir():
                has_content = any(dirpath.iterdir())
                if not has_content:
                    update_third_party = True

        if update_third_party:
            os.chdir(submodule_dir)
            subprocess.run(
                "git submodule sync --recursive && git submodule update --init --recursive",
                shell=True,
                check=True,
                text=True,
            )
        else:
            print(
                "\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
                flush=True,
            )

        # apply deep gemm patch
        deep_gemm_dir = "third_party/DeepGEMM"
        dst_path = os.path.join(submodule_dir, deep_gemm_dir)
        patch = "0001-DeepGEMM-95e81b3.patch"
        patch_source = os.path.join(submodule_dir, patch)
        patch_destination = os.path.join(dst_path, patch)
        if not os.path.exists(patch_destination):
            shutil.copy(patch_source, patch_destination)
            apply_cmd = ["git", "apply", patch]
            os.chdir(dst_path)
            subprocess.run(apply_cmd, check=True)
        os.chdir(original_dir)
    except subprocess.CalledProcessError:
        raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.")

ROOT_DIR = Path(__file__).parent.parent

# cannot import envs directly because it depends on fastdeploy,
# which is not installed yet
envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))

archs = json.loads(envs.FD_BUILDING_ARCS)
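# FD_BUILDING_ARCS is expected to be a JSON list of target SM versions; for example
# (illustrative), `export FD_BUILDING_ARCS="[80, 89, 90]"` would request builds for
# SM80, SM89 and SM90 in addition to whatever get_sm_version detects locally.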
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

update_git_repo()

def download_and_extract(url, destination_directory):
    """
    Download a .tar.gz file using wget to the destination directory
    and extract its contents without renaming the downloaded file.

    :param url: The URL of the .tar.gz file to download.
    :param destination_directory: The directory where the file should be downloaded and extracted.
    """
    os.makedirs(destination_directory, exist_ok=True)

    filename = os.path.basename(url)
    file_path = os.path.join(destination_directory, filename)

    try:
        subprocess.run(
            ["wget", "-O", file_path, url],
            check=True,
        )
        print(f"Downloaded: {file_path}")

        with tarfile.open(file_path, "r:gz") as tar:
            tar.extractall(path=destination_directory)
            print(f"Extracted: {file_path} to {destination_directory}")

        os.remove(file_path)
        print(f"Deleted downloaded file: {file_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error downloading file: {e}")
    except Exception as e:
        print(f"Error extracting file: {e}")

def get_sm_version(archs):
    """
    Get the target SM versions: the requested archs plus the detected device's compute capability.
    """
    arch_set = set(archs)
    try:
        prop = paddle.device.cuda.get_device_properties()
        cc = prop.major * 10 + prop.minor
        arch_set.add(cc)
    except ValueError:
        pass
    return list(arch_set)

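# Illustrative example: with archs == [80] on a machine whose visible GPU reports
# compute capability 9.0, get_sm_version([80]) returns [80, 90] (order not guaranteed,
# since the values round-trip through a set).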
def get_nvcc_version():
    """
    Get cuda version of nvcc.
    """
    nvcc_output = subprocess.check_output(["nvcc", "--version"], universal_newlines=True)
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    nvcc_cuda_version = float(output[release_idx].split(",")[0])
    return nvcc_cuda_version

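# Illustrative example: for an `nvcc --version` banner containing
# "Cuda compilation tools, release 12.4, V12.4.131", the token after "release" is
# "12.4,", so the parsed value is the float 12.4.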
def get_gencode_flags(archs):
    """
    Get gencode flags for current device or input.
    """
    cc_s = get_sm_version(archs)
    flags = []
    for cc_val in cc_s:
        if cc_val == 90:
            arch_code = "90a"
            flags += [
                "-gencode",
                f"arch=compute_{arch_code},code=sm_{arch_code}",
            ]
        elif cc_val == 100:  # Assuming 100 is the code for Blackwell SM10.x
            # Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
            # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
            # "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
            arch_code = "100a"
            flags += [
                "-gencode",
                f"arch=compute_{arch_code},code=sm_{arch_code}",
            ]
        else:
            flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
    return flags

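# Illustrative example: get_gencode_flags([80, 90]) yields, in some order,
#   ["-gencode", "arch=compute_80,code=sm_80", "-gencode", "arch=compute_90a,code=sm_90a"]
# plus an entry for the local GPU's architecture if get_sm_version detects one.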
def find_end_files(directory, end_str):
    """
    Find files with end str in directory.
    """
    gen_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith(end_str):
                gen_files.append(os.path.join(root, file))
    return gen_files

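# Illustrative example: find_end_files("gpu_ops/speculate_decoding", ".cu") collects
# every *.cu file under that directory tree; the per-architecture source lists below
# are assembled this way.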
if paddle.is_compiled_with_rocm():
    # NOTE(@duanyanhui): paddle.is_compiled_with_cuda() also returns True when paddle is
    # compiled with rocm, so we need to check for rocm first.
    sources = [
        "gpu_ops/save_with_output_msg.cc",
        "gpu_ops/get_output.cc",
        "gpu_ops/get_output_msg_with_topk.cc",
        "gpu_ops/save_output_msg_with_topk.cc",
        "gpu_ops/transfer_output.cc",
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/step.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/unset_data_ipc.cu",
        "gpu_ops/moe/tritonmoe_preprocess.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu",
        "gpu_ops/speculate_decoding/speculate_get_output.cc",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/speculate_decoding/speculate_clear_accept_nums.cu",
        "gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
        "gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
        "gpu_ops/speculate_decoding/speculate_save_output.cc",
        "gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
        "gpu_ops/speculate_decoding/speculate_step.cu",
        "gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
        "gpu_ops/speculate_decoding/speculate_update_v3.cu",
        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/step_reschedule.cu",
    ]
    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3"],
                "hipcc": [
                    "-O3",
                    "--gpu-max-threads-per-block=1024",
                    "-U__HIP_NO_HALF_OPERATORS__",
                    "-U__HIP_NO_HALF_CONVERSIONS__",
                    "-U__HIP_NO_BFLOAT16_OPERATORS__",
                    "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
                    "-U__HIP_NO_BFLOAT162_OPERATORS__",
                    "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
                    "-DPADDLE_DEV",
                    "-Ithird_party/nlohmann_json/include",
                    "-Igpu_ops",
                ],
            },
        ),
    )
elif paddle.is_compiled_with_cuda():
    sources = [
        "gpu_ops/save_with_output_msg.cc",
        "gpu_ops/get_output.cc",
        "gpu_ops/get_output_msg_with_topk.cc",
        "gpu_ops/save_output_msg_with_topk.cc",
        "gpu_ops/transfer_output.cc",
        "gpu_ops/set_mask_value.cu",
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/ngram_mask.cu",
        "gpu_ops/gather_idx.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/get_mm_split_fuse.cc",
        "gpu_ops/get_img_boundaries.cc",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/token_penalty_only_once.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/set_flags.cu",
        "gpu_ops/update_inputs_v1.cu",
        "gpu_ops/recover_decode_task.cu",
        "gpu_ops/step.cu",
        "gpu_ops/step_reschedule.cu",
        "gpu_ops/fused_get_rope.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/update_inputs_beam.cu",
        "gpu_ops/beam_search_softmax.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/unset_data_ipc.cu",
        "gpu_ops/read_data_ipc.cu",
        "gpu_ops/enforce_generation.cu",
        "gpu_ops/dequant_int8.cu",
        "gpu_ops/tune_cublaslt_gemm.cu",
        "gpu_ops/swap_cache_batch.cu",
        "gpu_ops/swap_cache.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/cpp_extensions.cc",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/per_token_quant_fp8.cu",
        "gpu_ops/extract_text_token_output.cu",
        "gpu_ops/update_split_fuse_input.cu",
        "gpu_ops/text_image_index_out.cu",
        "gpu_ops/text_image_gather_scatter.cu",
        "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
        "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
        "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/noaux_tc.cu",
        "gpu_ops/custom_all_reduce/all_reduce.cu",
        "gpu_ops/merge_prefill_decode_output.cu",
    ]

    # pd_disaggregation
    sources += [
        "gpu_ops/remote_cache_kv_ipc.cc",
        "gpu_ops/open_shm_and_get_meta_signal.cc",
        "gpu_ops/init_signal_layerwise.cc",
        "gpu_ops/get_data_ptr_ipc.cu",
        "gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
    ]

    dg_third_party_include_dirs = (
        "third_party/cutlass/include/cute",
        "third_party/cutlass/include/cutlass",
    )

    dg_include_dir = "third_party/DeepGEMM/deep_gemm/include"
    os.makedirs(dg_include_dir, exist_ok=True)

    for d in dg_third_party_include_dirs:
        dirname = d.split("/")[-1]
        src_dir = d
        dst_dir = os.path.join(dg_include_dir, dirname)

        # Remove existing directory if it exists
        if os.path.exists(dst_dir):
            if os.path.islink(dst_dir):
                os.unlink(dst_dir)
            else:
                shutil.rmtree(dst_dir)
        print(f"Copying {src_dir} to {dst_dir}")

        # Copy the directory
        try:
            shutil.copytree(src_dir, dst_dir)
        except Exception as e:
            raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")

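    # After the copy loop above, the DeepGEMM include tree should contain CUTLASS header
    # copies at third_party/DeepGEMM/deep_gemm/include/{cute,cutlass}, which is what the
    # deep_gemm package_data globs in the setup() call below pick up.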
    cc_compile_args = []
    nvcc_compile_args = get_gencode_flags(archs)
    nvcc_compile_args += ["-DPADDLE_DEV"]
    nvcc_compile_args += ["-DPADDLE_ON_INFERENCE"]
    nvcc_compile_args += ["-DPy_LIMITED_API=0x03090000"]
    nvcc_compile_args += [
        "-Igpu_ops/cutlass_kernels",
        "-Ithird_party/cutlass/include",
        "-Ithird_party/cutlass/tools/util/include",
        "-Igpu_ops/fp8_gemm_with_cutlass",
        "-Igpu_ops",
        "-Ithird_party/nlohmann_json/include",
    ]

    nvcc_version = get_nvcc_version()
    print(f"nvcc_version = {nvcc_version}")
    if nvcc_version >= 12.0:
        sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
    cc = max(get_sm_version(archs))
    print(f"cc = {cc}")
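    # `cc` is the highest compute capability among the requested archs and the detected
    # GPU; the blocks below use it to decide which kernel families are compiled
    # (e.g. cc >= 80 adds append attention and MoE sources, cc >= 89 enables FP8 paths).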
    fp8_auto_gen_directory = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
    if os.path.isdir(fp8_auto_gen_directory):
        shutil.rmtree(fp8_auto_gen_directory)

    if cc >= 75:
        nvcc_compile_args += [
            "-DENABLE_SCALED_MM_C2X=1",
            "-Igpu_ops/cutlass_kernels/w8a8",
        ]
        sources += [
            "gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu",
            "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu",
            "gpu_ops/quantization/common.cu",
        ]

    if cc >= 80:
        # append_attention
        sources += ["gpu_ops/append_attention.cu"]
        sources += find_end_files("gpu_ops/append_attn", ".cu")
        # mla
        sources += ["gpu_ops/multi_head_latent_attention.cu"]
        # gemm_dequant
        sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
        # speculate_decoding
        sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
        sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
        nvcc_compile_args += ["-DENABLE_BF16"]
        # moe
        os.system("python gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py")
        sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu")
        sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu")
        sources += find_end_files("gpu_ops/moe/", ".cu")
        nvcc_compile_args += ["-Igpu_ops/moe"]

    if cc >= 89:
        # Running generate fp8 gemm codes.
        # Common for SM89, SM90, SM100 (Blackwell)
        nvcc_compile_args += ["-DENABLE_FP8"]
        nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
        # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
        os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")

        if cc >= 90:  # Hopper and newer
            # SM90 (Hopper) specific auto-generation and flags
            if cc == 90:  # Only for SM90
                nvcc_compile_args += [
                    # The gencode for 90a is added in get_gencode_flags now
                    # "-gencode",
                    # "arch=compute_90a,code=compute_90a",
                    "-O3",
                    "-DNDEBUG",  # NDEBUG is common, consider moving if not specific to 90a
                ]
                print("SM90: Running SM90-specific FP8 kernel auto-generation.")
                os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
                os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
                os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")

                nvcc_compile_args += [
                    "-DENABLE_SCALED_MM_SM90=1",
                ]
                sources += [
                    "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
                    "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
                ]
            elif cc == 100 and nvcc_version >= 12.9:  # Blackwell SM100 specifics
                print("SM100 (Blackwell): Applying SM100 configurations.")
                nvcc_compile_args += [
                    # The gencode for 100a is added in get_gencode_flags
                    # "-gencode",
                    # "arch=compute_100a,code=compute_100a",
                    "-O3",  # Common optimization flag
                    "-DNDEBUG",  # Common debug flag
                    # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
                ]
                # Placeholder for SM100-specific kernel auto-generation scripts.
                # These might be needed if Blackwell has new FP8 hardware features
                # not covered by existing generic CUTLASS templates or SM90 scripts.
                # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).")
                # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py")  # Example
                # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py")  # Example

                # Add SM100 specific sources if any, e.g., for new hardware intrinsics
                # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"]  # Example
                pass  # No SM100 specific sources identified yet beyond what CUTLASS handles
            else:  # For other cc >= 90 (or SM100 without a new enough nvcc)
                print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
                os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
                os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

        else:  # For cc == 89 (Ada)
            print("SM89: Running generic FP8 kernel auto-generation.")
            os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
            os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

        # Common FP8 sources for SM89+
        sources += [
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/per_channel_fp8_fp8_half_gemm.cu",
            "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu",
            "gpu_ops/scaled_gemm_f8_i4_f16_gemm.cu",
            "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
            "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
            "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
            "gpu_ops/fused_hadamard_quant_fp8.cu",
        ]

        sources += find_end_files(fp8_auto_gen_directory, ".cu")

    if cc >= 90 and nvcc_version >= 12.0:
        # Hopper optimized mla
        sources += find_end_files("gpu_ops/mla_attn", ".cu")
        sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
        sources += find_end_files("gpu_ops/moba_attn/moba_decoder_attn/", ".cu")
        sources += find_end_files("gpu_ops/moba_attn/moba_encoder_attn/", ".cu")
        sources += find_end_files("gpu_ops/moba_attn/moba_process/", ".cu")
        sources += ["gpu_ops/moba_attn/moba_attn.cu"]
        os.system("python utils/auto_gen_w4afp8_gemm_kernel.py")
        sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
        os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py")
        sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu")
        os.system("python gpu_ops/machete/generate.py")
        sources += find_end_files("gpu_ops/machete", ".cu")
        cc_compile_args += ["-DENABLE_MACHETE"]

    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            sources=sources,
            extra_compile_args={"cxx": cc_compile_args, "nvcc": nvcc_compile_args},
            libraries=["cublasLt"],
            extra_link_args=["-lcuda"],
        ),
        packages=find_packages(where="third_party/DeepGEMM"),
        package_dir={"": "third_party/DeepGEMM"},
        package_data={
            "deep_gemm": [
                "include/deep_gemm/**/*",
                "include/cute/**/*",
                "include/cutlass/**/*",
            ]
        },
        include_package_data=True,
    )
elif paddle.is_compiled_with_xpu():
    assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            extra_compile_args={
                "nvcc": [
                    "-DPADDLE_DEV",
                    "-DPADDLE_WITH_CUSTOM_DEVICE",
                ]
            },
            sources=[
                "gpu_ops/save_with_output_msg.cc",
                "gpu_ops/get_output.cc",
                "gpu_ops/get_output_msg_with_topk.cc",
                "gpu_ops/save_output_msg_with_topk.cc",
                "gpu_ops/transfer_output.cc",
                "gpu_ops/get_padding_offset.cu",
                "gpu_ops/set_value_by_flags.cu",
                "gpu_ops/rebuild_padding.cu",
                "gpu_ops/update_inputs.cu",
                "gpu_ops/stop_generation_multi_ends.cu",
                "gpu_ops/step.cu",
                "gpu_ops/token_penalty_multi_scores.cu",
                "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
                "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
                "iluvatar_ops/moe_dispatch.cu",
                "iluvatar_ops/moe_reduce.cu",
                "iluvatar_ops/paged_attn.cu",
                "iluvatar_ops/w8a16_group_gemm.cu",
                "iluvatar_ops/runtime/iluvatar_context.cc",
            ],
            include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
            extra_link_args=[
                "-lcuinfer",
            ],
        ),
    )
elif paddle.is_compiled_with_custom_device("gcu"):
    setup(
        name="fastdeploy_ops",
        ext_modules=CppExtension(
            sources=[
                "gpu_ops/save_with_output_msg.cc",
                "gpu_ops/get_output.cc",
                "gpu_ops/get_output_msg_with_topk.cc",
            ]
        ),
    )
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
    maca_path = os.getenv("MACA_PATH", "/opt/maca")
    sources = [
        "gpu_ops/update_inputs_v1.cu",
        "gpu_ops/save_with_output_msg.cc",
        "gpu_ops/get_output.cc",
        "gpu_ops/get_output_msg_with_topk.cc",
        "gpu_ops/save_output_msg_with_topk.cc",
        "gpu_ops/transfer_output.cc",
        "gpu_ops/save_with_output.cc",
        "gpu_ops/set_mask_value.cu",
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/ngram_mask.cu",
        "gpu_ops/gather_idx.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/token_penalty_only_once.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/set_flags.cu",
        "gpu_ops/fused_get_rope.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/update_inputs_beam.cu",
        "gpu_ops/beam_search_softmax.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/step.cu",
        "gpu_ops/step_reschedule.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/read_data_ipc.cu",
        "gpu_ops/dequant_int8.cu",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/extract_text_token_output.cu",
        "gpu_ops/moe/tritonmoe_preprocess.cu",
        "gpu_ops/moe/moe_topk_select.cu",
        "gpu_ops/recover_decode_task.cu",
    ]

    sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
    sources += find_end_files("gpu_ops/speculate_decoding", ".cc")

    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": [
                    "-O3",
                    "-Ithird_party/nlohmann_json/include",
                    "-Igpu_ops",
                    "-DPADDLE_DEV",
                    "-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
                ],
            },
            library_dirs=[os.path.join(maca_path, "lib")],
            extra_link_args=["-lruntime_cu"],
            include_dirs=[
                os.path.join(maca_path, "include"),
                os.path.join(maca_path, "include/mcr"),
                os.path.join(maca_path, "include/common"),
            ],
        ),
    )
else:
    use_bf16 = envs.FD_CPU_USE_BF16 == "True"

    # cc flags
    paddle_extra_compile_args = [
        "-std=c++17",
        "-shared",
        "-fPIC",
        "-Wno-parentheses",
        "-DPADDLE_WITH_CUSTOM_KERNEL",
        "-DPADDLE_ON_INFERENCE",
        "-Wall",
        "-O3",
        "-g",
        "-lstdc++fs",
        "-D_GLIBCXX_USE_CXX11_ABI=1",
        "-DPy_LIMITED_API=0x03090000",
    ]

    setup(
        name="fastdeploy_cpu_ops",
        ext_modules=CppExtension(
            sources=[
                "gpu_ops/save_with_output_msg.cc",
                "gpu_ops/get_output.cc",
                "gpu_ops/get_output_msg_with_topk.cc",
                "gpu_ops/save_output_msg_with_topk.cc",
                "gpu_ops/transfer_output.cc",
                "cpu_ops/rebuild_padding.cc",
                "cpu_ops/simd_sort.cc",
                "cpu_ops/set_value_by_flags.cc",
                "cpu_ops/token_penalty_multi_scores.cc",
                "cpu_ops/stop_generation_multi_ends.cc",
                "cpu_ops/update_inputs.cc",
                "cpu_ops/get_padding_offset.cc",
            ],
            extra_link_args=[
                "-Wl,-rpath,$ORIGIN/x86-simd-sort/builddir",
                "-Wl,-rpath,$ORIGIN/xFasterTransformer/build",
            ],
            extra_compile_args=paddle_extra_compile_args,
        ),
        packages=find_namespace_packages(where="third_party"),
        package_dir={"": "third_party"},
        include_package_data=True,
    )