mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	 f0f00a6025
			
		
	
	f0f00a6025
	
	
		
			
	
		
	
	
		
			Some checks failed
		
		
	
	Deploy GitHub Pages / deploy (push) Has been cancelled
				
			* delete nonzero * delete setup_ops_base.py * check if * check gcp infer_seed.cpu() * fix repetition_early_stopper_kernel cuda 700
		
			
				
	
	
		
			694 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			694 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """setup for FastDeploy custom ops"""
 | |
| import importlib
 | |
| import json
 | |
| import os
 | |
| import shutil
 | |
| import subprocess
 | |
| import sys
 | |
| import tarfile
 | |
| from pathlib import Path
 | |
| 
 | |
| import paddle
 | |
| from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
 | |
| from setuptools import find_namespace_packages, find_packages
 | |
| 
 | |
| 
 | |
def load_module_from_path(module_name, path):
    """
    Import a Python source file as a module without going through package lookup.

    The module is registered in ``sys.modules`` under ``module_name`` before its
    code runs, matching the standard "import a source file directly" recipe.
    If executing the file fails, the half-initialised entry is removed again so
    a later import attempt does not pick up a broken module.

    :param module_name: Name to register the module under in ``sys.modules``.
    :param path: Filesystem path of the Python source file to execute.
    :return: The executed module object.
    :raises ImportError: If no import spec/loader can be created for ``path``
        (e.g. the file suffix is not a recognised Python source suffix).
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule is loaded; import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the import system: a failed import must not stay registered.
        sys.modules.pop(module_name, None)
        raise
    return module
 | |
| 
 | |
| 
 | |
# Repository root: the parent of the directory that contains this file.
ROOT_DIR = Path(__file__).parent.parent

# ``envs`` lives in the ``fastdeploy`` package, which is not installed yet
# while these ops are being built, so it cannot be imported normally.
# Load ``fastdeploy/envs.py`` straight from its path instead.
envs = load_module_from_path(
    module_name="envs",
    path=os.path.join(ROOT_DIR, "fastdeploy", "envs.py"),
)

# Build settings read from ``envs``:
#   archs    - target compute capabilities, decoded from the JSON string FD_BUILDING_ARCS
#   use_bf16 - True only when FD_CPU_USE_BF16 is exactly the string "True"
archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
 | |
| 
 | |
| 
 | |
def download_and_extract(url, destination_directory):
    """
    Download a .tar.gz archive with ``wget`` and extract it into a directory.

    The archive is saved under its URL basename inside
    ``destination_directory``, extracted there, and deleted afterwards.
    Errors are reported with ``print`` and not re-raised, so this is a
    best-effort step for callers.

    :param url: The URL of the .tar.gz file to download.
    :param destination_directory: The directory where the file should be
        downloaded and extracted; created if it does not exist.
    """
    os.makedirs(destination_directory, exist_ok=True)

    filename = os.path.basename(url)
    file_path = os.path.join(destination_directory, filename)

    try:
        subprocess.run(["wget", "-O", file_path, url], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers ``wget`` itself not being installed.
        print(f"Error downloading file: {e}")
        # ``wget -O`` may leave an empty or truncated file behind on failure;
        # remove it so it is not mistaken for a valid archive later.
        try:
            os.remove(file_path)
        except OSError:
            pass
        return
    print(f"Downloaded: {file_path}")

    try:
        with tarfile.open(file_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters reject absolute paths, ".." members,
                # links escaping the destination, device files, etc.
                tar.extractall(path=destination_directory, filter="data")
            else:
                # Python without extraction-filter support.
                tar.extractall(path=destination_directory)
            print(f"Extracted: {file_path} to {destination_directory}")
        os.remove(file_path)
        print(f"Deleted downloaded file: {file_path}")
    except Exception as e:
        print(f"Error extracting file: {e}")
 | |
| 
 | |
| 
 | |
def clone_git_repo(version, repo_url, destination_path):
    """
    Clone a single branch (or tag) of a git repository.

    :param version: Branch or tag name passed to ``git clone -b``.
    :param repo_url: URL (or local path) of the repository to clone.
    :param destination_path: Directory to clone into.
    :return: True if the clone succeeded; False if git exited non-zero or
        could not be started at all.
    """
    clone_cmd = [
        "git",
        "clone",
        "-b",
        version,
        "--single-branch",
        repo_url,
        destination_path,
    ]
    try:
        subprocess.run(clone_cmd, check=True)
    except subprocess.CalledProcessError:
        return False
    except OSError as e:
        # A missing ``git`` binary raises FileNotFoundError rather than
        # CalledProcessError; honour the boolean contract for it too.
        print(f"Failed to run git: {e}")
        return False
    return True
 | |
| 
 | |
| 
 | |
def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
    """
    Reset a git checkout to a commit and/or apply a patch to it.

    The patch file is copied from ``cur_path`` into ``dst_path`` and then
    applied with ``git apply``. The git commands run with ``dst_path`` as the
    working directory. Afterwards the working directory is switched to
    ``cur_path``, whether or not the git commands succeeded.

    :param cur_path: Directory to switch back to afterwards; ``patch`` is read from here.
    :param dst_path: Root of the git checkout to modify.
    :param commit_id: Commit to ``git reset --hard`` to, or None to skip the reset.
    :param patch: Patch file name relative to ``cur_path``, or None to skip patching.
    :return: True if every git command succeeded, False if one exited non-zero.
    """
    if patch is not None:
        shutil.copy(os.path.join(cur_path, patch), os.path.join(dst_path, patch))

    try:
        os.chdir(dst_path)
        if commit_id is not None:
            subprocess.run(["git", "reset", "--hard", commit_id], check=True)
        if patch is not None:
            subprocess.run(["git", "apply", patch], check=True)
        return True
    except subprocess.CalledProcessError:
        return False
    finally:
        # Always restore the working directory so relative paths used by the
        # caller (e.g. "third_party/...") keep resolving correctly.
        os.chdir(cur_path)
 | |
| 
 | |
| 
 | |
def get_sm_version(archs):
    """
    Collect the CUDA compute capabilities to compile for.

    Starts from the caller-supplied ``archs``. When paddle can report the
    current CUDA device's properties, that device's capability, encoded as
    ``major * 10 + minor``, is added too. Duplicates are dropped.

    :param archs: Iterable of compute capabilities requested for the build.
    :return: List of the unique capabilities (order not guaranteed).
    """
    capabilities = set(archs)
    try:
        device_prop = paddle.device.cuda.get_device_properties()
        capabilities.add(device_prop.major * 10 + device_prop.minor)
    except ValueError:
        # NOTE(review): presumably raised when no CUDA device can be queried
        # (e.g. CPU-only paddle); the requested archs are used as-is.
        return list(capabilities)
    return list(capabilities)
 | |
| 
 | |
| 
 | |
def get_nvcc_version():
    """
    Return the CUDA release reported by ``nvcc --version`` as a float.

    The token that follows the word ``release`` is used, with its trailing
    comma stripped (e.g. ``release 12.4,`` -> ``12.4``).

    NOTE: a float cannot distinguish e.g. "12.1" from "12.10"; callers only
    compare against single-digit minor versions.
    """
    tokens = subprocess.check_output(["nvcc", "--version"], universal_newlines=True).split()
    version_token = tokens[tokens.index("release") + 1]
    return float(version_token.partition(",")[0])
 | |
| 
 | |
| 
 | |
def get_gencode_flags(archs):
    """
    Build nvcc ``-gencode`` flags for every target compute capability.

    Capabilities come from ``get_sm_version(archs)``. SM90 and SM100 are
    emitted as their architecture-specific "a" variants (``90a``/``100a``),
    which CUTLASS needs for arch-specific features. See
    https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/

    :param archs: Compute capabilities requested for the build.
    :return: Flat list of alternating ``-gencode`` / ``arch=...,code=...`` entries.
    """
    arch_specific_codes = {90: "90a", 100: "100a"}
    flags = []
    for capability in get_sm_version(archs):
        code = arch_specific_codes.get(capability, capability)
        flags.extend(["-gencode", f"arch=compute_{code},code=sm_{code}"])
    return flags
 | |
| 
 | |
| 
 | |
def find_end_files(directory, end_str):
    """
    Recursively list files under ``directory`` whose names end with ``end_str``.

    :param directory: Root directory to walk; a missing directory yields no results.
    :param end_str: Filename suffix to match (anything ``str.endswith`` accepts).
    :return: Paths joined onto their walk root, in ``os.walk`` order.
    """
    return [
        os.path.join(walk_root, name)
        for walk_root, _subdirs, names in os.walk(directory)
        for name in names
        if name.endswith(end_str)
    ]
 | |
| 
 | |
| 
 | |
def _run_codegen(script_path):
    """
    Run a kernel code-generation script with the interpreter running this setup.

    ``sys.executable`` is used instead of a bare ``python`` so the generator
    runs in the same environment as the build. A non-zero exit raises
    ``subprocess.CalledProcessError``, so generator failures stop the build
    here instead of silently producing an incomplete source list.

    :param script_path: Path of the generator script, relative to the cwd.
    """
    subprocess.run([sys.executable, script_path], check=True)


def _ensure_git_repo(repo_dir, version, repo_url, display_name):
    """
    Clone ``repo_url`` at ``version`` into ``repo_dir`` unless it already has content.

    :param repo_dir: Target directory; created if missing.
    :param version: Branch or tag passed to ``clone_git_repo``.
    :param repo_url: Repository URL.
    :param display_name: Name used in the failure message.
    :return: True if a fresh clone was made, False if ``repo_dir`` was already populated.
    :raises ValueError: If ``repo_dir`` is still empty after cloning.
    """
    if os.path.exists(repo_dir) and os.listdir(repo_dir):
        return False
    os.makedirs(repo_dir, exist_ok=True)
    clone_git_repo(version, repo_url, repo_dir)
    if not os.listdir(repo_dir):
        raise ValueError(f"Git clone {display_name} failed!")
    return True


if paddle.is_compiled_with_rocm():
    # NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
    # so we need to check if paddle compiled with rocm at first.
    json_dir = "third_party/nlohmann_json"
    _ensure_git_repo(json_dir, "v3.11.3", "https://bgithub.xyz/nlohmann/json.git", "nlohmann_json")
    sources = [
        "gpu_ops/save_with_output_msg.cc",
        "gpu_ops/get_output.cc",
        "gpu_ops/get_output_msg_with_topk.cc",
        "gpu_ops/save_output_msg_with_topk.cc",
        "gpu_ops/transfer_output.cc",
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/step.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/moe/tritonmoe_preprocess.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu",
        "gpu_ops/speculate_decoding/speculate_get_output.cc",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/speculate_decoding/speculate_clear_accept_nums.cu",
        "gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
        "gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
        "gpu_ops/speculate_decoding/speculate_save_output.cc",
        "gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
        "gpu_ops/speculate_decoding/speculate_step.cu",
        "gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
        "gpu_ops/speculate_decoding/speculate_update_v3.cu",
        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/step_reschedule.cu",
    ]
    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3"],
                "hipcc": [
                    "-O3",
                    "--gpu-max-threads-per-block=1024",
                    "-U__HIP_NO_HALF_OPERATORS__",
                    "-U__HIP_NO_HALF_CONVERSIONS__",
                    "-U__HIP_NO_BFLOAT16_OPERATORS__",
                    "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
                    "-U__HIP_NO_BFLOAT162_OPERATORS__",
                    "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
                    "-DPADDLE_DEV",
                    "-Ithird_party/nlohmann_json/include",
                    "-Igpu_ops",
                ],
            },
        ),
    )
elif paddle.is_compiled_with_cuda():
    sources = [
        "gpu_ops/save_with_output_msg.cc",
        "gpu_ops/get_output.cc",
        "gpu_ops/get_output_msg_with_topk.cc",
        "gpu_ops/save_output_msg_with_topk.cc",
        "gpu_ops/transfer_output.cc",
        "gpu_ops/set_mask_value.cu",
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/ngram_mask.cu",
        "gpu_ops/gather_idx.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/get_mm_split_fuse.cc",
        "gpu_ops/get_img_boundaries.cc",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/token_penalty_only_once.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/set_flags.cu",
        "gpu_ops/update_inputs_v1.cu",
        "gpu_ops/recover_decode_task.cu",
        "gpu_ops/step.cu",
        "gpu_ops/step_reschedule.cu",
        "gpu_ops/fused_get_rope.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/update_inputs_beam.cu",
        "gpu_ops/beam_search_softmax.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/read_data_ipc.cu",
        "gpu_ops/enforce_generation.cu",
        "gpu_ops/dequant_int8.cu",
        "gpu_ops/tune_cublaslt_gemm.cu",
        "gpu_ops/swap_cache_batch.cu",
        "gpu_ops/swap_cache.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/cpp_extensions.cc",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/per_token_quant_fp8.cu",
        "gpu_ops/extract_text_token_output.cu",
        "gpu_ops/update_split_fuse_input.cu",
        "gpu_ops/text_image_index_out.cu",
        "gpu_ops/text_image_gather_scatter.cu",
        "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
        "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
        "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
        "gpu_ops/fused_rotary_position_encoding.cu",
        "gpu_ops/noaux_tc.cu",
        "gpu_ops/custom_all_reduce/all_reduce.cu",
        "gpu_ops/merge_prefill_decode_output.cu",
    ]

    # pd_disaggregation
    sources += [
        "gpu_ops/remote_cache_kv_ipc.cc",
        "gpu_ops/open_shm_and_get_meta_signal.cc",
        "gpu_ops/init_signal_layerwise.cc",
        "gpu_ops/get_data_ptr_ipc.cu",
        "gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
    ]

    cutlass_dir = "third_party/cutlass"
    _ensure_git_repo(cutlass_dir, "v3.8.0", "https://github.com/NVIDIA/cutlass.git", "cutlass")

    # deep gemm: cloned from main, then (on a fresh clone only) pinned to a
    # known commit and patched.
    deep_gemm_dir = "third_party/DeepGEMM"
    if _ensure_git_repo(deep_gemm_dir, "main", "https://github.com/deepseek-ai/DeepGEMM.git", "DeepGEMM"):
        cur_path = os.path.dirname(os.path.abspath(__file__))
        dst_path = os.path.join(cur_path, deep_gemm_dir)
        commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
        patch = "0001-DeepGEMM-95e81b3.patch"
        if not process_git_repo(cur_path, dst_path, commit_id, patch):
            raise RuntimeError(f"Failed to reset DeepGEMM to {commit_id} or apply {patch}")

    # DeepGEMM is packaged together with copies of the CuTe/CUTLASS headers.
    dg_third_party_include_dirs = (
        "third_party/cutlass/include/cute",
        "third_party/cutlass/include/cutlass",
    )

    dg_include_dir = "third_party/DeepGEMM/deep_gemm/include"
    os.makedirs(dg_include_dir, exist_ok=True)

    for d in dg_third_party_include_dirs:
        dirname = d.split("/")[-1]
        src_dir = d
        dst_dir = os.path.join(dg_include_dir, dirname)

        # Remove existing directory if it exists
        if os.path.exists(dst_dir):
            if os.path.islink(dst_dir):
                os.unlink(dst_dir)
            else:
                shutil.rmtree(dst_dir)
        print(f"Copying {src_dir} to {dst_dir}")

        # Copy the directory
        try:
            shutil.copytree(src_dir, dst_dir)
        except Exception as e:
            raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")

    json_dir = "third_party/nlohmann_json"
    _ensure_git_repo(json_dir, "v3.11.3", "https://github.com/nlohmann/json.git", "nlohmann_json")

    nvcc_compile_args = get_gencode_flags(archs)
    nvcc_compile_args += ["-DPADDLE_DEV"]
    nvcc_compile_args += ["-DPADDLE_ON_INFERENCE"]
    nvcc_compile_args += ["-DPy_LIMITED_API=0x03090000"]
    nvcc_compile_args += [
        "-Igpu_ops/cutlass_kernels",
        "-Ithird_party/cutlass/include",
        "-Ithird_party/cutlass/tools/util/include",
        "-Igpu_ops/fp8_gemm_with_cutlass",
        "-Igpu_ops",
        "-Ithird_party/nlohmann_json/include",
    ]
    nvcc_version = get_nvcc_version()
    print(f"nvcc_version = {nvcc_version}")
    if nvcc_version >= 12.0:
        sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]

    sm_versions = get_sm_version(archs)
    if not sm_versions:
        raise ValueError(
            "No CUDA compute capability to build for: FD_BUILDING_ARCS is empty " "and no CUDA device was detected."
        )
    # Feature selection below is driven by the highest requested capability.
    cc = max(sm_versions)
    print(f"cc = {cc}")
    fp8_auto_gen_directory = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
    if os.path.isdir(fp8_auto_gen_directory):
        shutil.rmtree(fp8_auto_gen_directory)

    if cc >= 75:
        nvcc_compile_args += [
            "-DENABLE_SCALED_MM_C2X=1",
            "-Igpu_ops/cutlass_kernels/w8a8",
        ]
        sources += [
            "gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu",
            "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu",
            "gpu_ops/quantization/common.cu",
        ]

    if cc >= 80:
        # append_attention
        sources += ["gpu_ops/append_attention.cu"]
        sources += find_end_files("gpu_ops/append_attn", ".cu")
        # mla
        sources += ["gpu_ops/multi_head_latent_attention.cu"]
        # gemm_dequant
        sources += ["gpu_ops/int8_gemm_with_cutlass/gemm_dequant.cu"]
        # speculate_decoding
        sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
        sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
        nvcc_compile_args += ["-DENABLE_BF16"]
        # moe
        _run_codegen("gpu_ops/moe/moe_wna16_marlin_utils/generate_kernels.py")
        sources += find_end_files("gpu_ops/cutlass_kernels/moe_gemm/", ".cu")
        sources += find_end_files("gpu_ops/cutlass_kernels/w4a8_moe/", ".cu")
        sources += find_end_files("gpu_ops/moe/", ".cu")
        nvcc_compile_args += ["-Igpu_ops/moe"]

    if cc >= 89:
        # FP8 GEMM support, common to SM89, SM90 and SM100 (Blackwell).
        nvcc_compile_args += ["-DENABLE_FP8"]
        nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
        # Generic generator; architecture-specific templates are chosen by CUTLASS.
        _run_codegen("utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")

        if cc >= 90:
            if cc == 90:
                # SM90 (Hopper). The 90a gencode flag is added by get_gencode_flags().
                nvcc_compile_args += [
                    "-O3",
                    "-DNDEBUG",
                ]
                print("SM90: Running SM90-specific FP8 kernel auto-generation.")
                _run_codegen("utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
                _run_codegen("utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
                _run_codegen("utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")

                nvcc_compile_args += [
                    "-DENABLE_SCALED_MM_SM90=1",
                ]
                sources += [
                    "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu",
                    "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
                    "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
                ]
            elif cc == 100 and nvcc_version >= 12.9:
                # SM100 (Blackwell). The 100a gencode flag is added by get_gencode_flags().
                # No SM100-specific generators or sources yet beyond what CUTLASS handles.
                print("SM100 (Blackwell): Applying SM100 configurations.")
                nvcc_compile_args += [
                    "-O3",
                    "-DNDEBUG",
                ]
            else:
                # Any other cc >= 90: e.g. SM100 with nvcc < 12.9, or newer
                # architectures. Fall back to the generic FP8 generators.
                print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
                _run_codegen("utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
                _run_codegen("utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
        else:
            # cc == 89 (Ada)
            print("SM89: Running generic FP8 kernel auto-generation.")
            _run_codegen("utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
            _run_codegen("utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")

        # Common FP8 sources for SM89+
        sources += [
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.cu",
            "gpu_ops/fp8_gemm_with_cutlass/per_channel_fp8_fp8_half_gemm.cu",
            "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu",
            "gpu_ops/scaled_gemm_f8_i4_f16_gemm.cu",
            "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
            "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
            "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
            "gpu_ops/fused_hadamard_quant_fp8.cu",
        ]

        # Kernels emitted by the generators above.
        sources += find_end_files(fp8_auto_gen_directory, ".cu")

    if cc >= 90 and nvcc_version >= 12.0:
        # Hopper optimized mla
        sources += find_end_files("gpu_ops/mla_attn", ".cu")
        sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"]
        _run_codegen("utils/auto_gen_w4afp8_gemm_kernel.py")
        sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")

    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            sources=sources,
            extra_compile_args={"nvcc": nvcc_compile_args},
            libraries=["cublasLt"],
        ),
        packages=find_packages(where="third_party/DeepGEMM"),
        package_dir={"": "third_party/DeepGEMM"},
        package_data={
            "deep_gemm": [
                "include/deep_gemm/**/*",
                "include/cute/**/*",
                "include/cutlass/**/*",
            ]
        },
        include_package_data=True,
    )
elif paddle.is_compiled_with_xpu():
    # XPU ops are built by a different setup script. Use ``raise`` rather than
    # ``assert`` so the guard still fires under ``python -O``.
    raise RuntimeError("In XPU, we should use setup_ops.py in xpu_ops/src, not this.")
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            extra_compile_args={
                "nvcc": [
                    "-DPADDLE_DEV",
                    "-DPADDLE_WITH_CUSTOM_DEVICE",
                ]
            },
            sources=[
                "gpu_ops/save_with_output_msg.cc",
                "gpu_ops/get_output.cc",
                "gpu_ops/get_output_msg_with_topk.cc",
                "gpu_ops/save_output_msg_with_topk.cc",
                "gpu_ops/transfer_output.cc",
                "gpu_ops/get_padding_offset.cu",
                "gpu_ops/set_value_by_flags.cu",
                "gpu_ops/rebuild_padding.cu",
                "gpu_ops/update_inputs.cu",
                "gpu_ops/stop_generation_multi_ends.cu",
                "gpu_ops/step.cu",
                "gpu_ops/token_penalty_multi_scores.cu",
                "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
                "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
                "iluvatar_ops/moe_dispatch.cu",
                "iluvatar_ops/moe_reduce.cu",
                "iluvatar_ops/paged_attn.cu",
                "iluvatar_ops/w8a16_group_gemm.cu",
                "iluvatar_ops/runtime/iluvatar_context.cc",
            ],
            include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
            extra_link_args=[
                "-lcuinfer",
            ],
        ),
    )
elif paddle.is_compiled_with_custom_device("gcu"):
    setup(
        name="fastdeploy_ops",
        ext_modules=CppExtension(
            sources=[
                "gpu_ops/save_with_output_msg.cc",
                "gpu_ops/get_output.cc",
                "gpu_ops/get_output_msg_with_topk.cc",
            ]
        ),
    )
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
    maca_path = os.getenv("MACA_PATH", "/opt/maca")
    json_dir = "third_party/nlohmann_json"
    _ensure_git_repo(
        json_dir,
        "v3.11.3",
        "https://gitee.com/learnlov/mirrors_nlohmann_json.git",
        "nlohmann_json",
    )
    sources = [
        "gpu_ops/save_with_output.cc",
        "gpu_ops/set_mask_value.cu",
        "gpu_ops/set_value_by_flags.cu",
        "gpu_ops/ngram_mask.cu",
        "gpu_ops/gather_idx.cu",
        "gpu_ops/get_output_ep.cc",
        "gpu_ops/token_penalty_multi_scores.cu",
        "gpu_ops/token_penalty_only_once.cu",
        "gpu_ops/stop_generation.cu",
        "gpu_ops/stop_generation_multi_ends.cu",
        "gpu_ops/set_flags.cu",
        "gpu_ops/fused_get_rope.cu",
        "gpu_ops/get_padding_offset.cu",
        "gpu_ops/update_inputs.cu",
        "gpu_ops/update_inputs_beam.cu",
        "gpu_ops/beam_search_softmax.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/step.cu",
        "gpu_ops/step_reschedule.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/set_data_ipc.cu",
        "gpu_ops/read_data_ipc.cu",
        "gpu_ops/dequant_int8.cu",
        "gpu_ops/share_external_data.cu",
        "gpu_ops/extract_text_token_output.cu",
        "gpu_ops/moe/tritonmoe_preprocess.cu",
        "gpu_ops/moe/moe_topk_select.cu",
        "gpu_ops/recover_decode_task.cu",
    ]

    sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
    sources += find_end_files("gpu_ops/speculate_decoding", ".cc")

    setup(
        name="fastdeploy_ops",
        ext_modules=CUDAExtension(
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": [
                    "-O3",
                    "-Ithird_party/nlohmann_json/include",
                    "-Igpu_ops",
                    "-DPADDLE_DEV",
                    "-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
                ],
            },
            library_dirs=[os.path.join(maca_path, "lib")],
            extra_link_args=["-lruntime_cu"],
            include_dirs=[
                os.path.join(maca_path, "include"),
                os.path.join(maca_path, "include/mcr"),
                os.path.join(maca_path, "include/common"),
            ],
        ),
    )
else:
    # CPU-only build.
    use_bf16 = envs.FD_CPU_USE_BF16 == "True"

    # cc flags
    paddle_extra_compile_args = [
        "-std=c++17",
        "-shared",
        "-fPIC",
        "-Wno-parentheses",
        "-DPADDLE_WITH_CUSTOM_KERNEL",
        "-DPADDLE_ON_INFERENCE",
        "-Wall",
        "-O3",
        "-g",
        "-lstdc++fs",
        "-D_GLIBCXX_USE_CXX11_ABI=1",
        "-DPy_LIMITED_API=0x03090000",
    ]

    setup(
        name="fastdeploy_cpu_ops",
        ext_modules=CppExtension(
            sources=[
                "gpu_ops/save_with_output_msg.cc",
                "gpu_ops/get_output.cc",
                "gpu_ops/get_output_msg_with_topk.cc",
                "gpu_ops/save_output_msg_with_topk.cc",
                "gpu_ops/transfer_output.cc",
                "cpu_ops/rebuild_padding.cc",
                "cpu_ops/simd_sort.cc",
                "cpu_ops/set_value_by_flags.cc",
                "cpu_ops/token_penalty_multi_scores.cc",
                "cpu_ops/stop_generation_multi_ends.cc",
                "cpu_ops/update_inputs.cc",
                "cpu_ops/get_padding_offset.cc",
            ],
            extra_link_args=[
                "-Wl,-rpath,$ORIGIN/x86-simd-sort/builddir",
                "-Wl,-rpath,$ORIGIN/xFasterTransformer/build",
            ],
            extra_compile_args=paddle_extra_compile_args,
        ),
        packages=find_namespace_packages(where="third_party"),
        package_dir={"": "third_party"},
        include_package_data=True,
    )
 |