From 92c2cfa2e7f5d7b87bb070c6ec9e8a59c554f766 Mon Sep 17 00:00:00 2001 From: Jiang-Jia-Jun Date: Sun, 29 Jun 2025 23:29:37 +0000 Subject: [PATCH] Sync v2.0 version of code to github repo --- .clang-format | 29 + .gitignore | 6 +- .pre-commit-config.yaml | 19 +- README.md | 156 +- benchmarks/README.md | 106 + benchmarks/backend_request_func.py | 700 ++++++ benchmarks/benchmark_dataset.py | 309 +++ benchmarks/benchmark_serving.py | 1141 ++++++++++ benchmarks/benchmark_utils.py | 90 + benchmarks/requirements.txt | 5 + benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml | 8 + benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml | 5 + benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml | 8 + .../yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml | 10 + benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml | 5 + .../yaml/eb45-21b-a3b-32k-wint4-a10.yaml | 5 + benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml | 6 + benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml | 6 + benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml | 5 + .../yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml | 12 + .../eb45-32k-tensorwise-fp8-h800-tp8.yaml | 11 + benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml | 5 + .../yaml/eb45-32k-w4a8c8-tp4_decode.yaml | 15 + .../yaml/eb45-32k-w4a8c8-tp4_prefill.yaml | 12 + benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml | 6 + benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml | 5 + .../yaml/eb45-32k-wint4-h800-dp8_decode.yaml | 13 + .../yaml/eb45-32k-wint4-h800-dp8_prefill.yaml | 13 + .../yaml/eb45-32k-wint4-mtp-h800-tp4.yaml | 6 + .../yaml/eb45-32k-wint4-mtp-tp4-decode.yaml | 13 + .../yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml | 12 + benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml | 5 + benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml | 5 + .../eb45-32k-wint4-prefixcache-a800-tp4.yaml | 8 + .../yaml/eb45-32k-wint4-tp4_decode.yaml | 15 + .../yaml/eb45-32k-wint4-tp4_prefill.yaml | 12 + benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml | 5 + benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml | 5 + .../eb45-32k-wint8-prefixcache-a800-tp8.yaml | 9 + .../yaml/eb45-vl-32k-wint4-a800-tp8.yaml | 9 + .../yaml/eb45-vl-32k-wint4-h800-tp8.yaml | 11 + benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml | 9 + .../yaml/eb45-vl-32k-wint8-a800-tp8.yaml | 9 + .../yaml/eb45-vl-32k-wint8-h800-tp8.yaml | 11 + benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml | 9 + .../eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml | 5 + ...eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml | 5 + ...eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml | 6 + ...b45t_0dot3b-32k-wint8-h800-tp1-static.yaml | 6 + .../eb45t_21b-32k-bf16-h800-tp1-static.yaml | 5 + .../eb45t_21b-32k-wint4-h800-tp1-static.yaml | 6 + .../eb45t_300b-32k-wint4-h800-tp4-static.yaml | 6 + .../qwen2_7b-32k-bf16-a30-tp1-static.yaml | 5 + .../qwen2_7b-32k-bf16-h800-tp1-static.yaml | 5 + .../yaml/qwen2_7b-32k-bf16-h800-tp1.yaml | 4 + .../qwen2_7b-32k-fp8-h800-tp1-static.yaml | 6 + .../yaml/qwen2_7b-32k-fp8-h800-tp1.yaml | 5 + .../yaml/qwen2_7b-32k-wint8-h800-tp1.yaml | 5 + .../qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml | 5 + ...qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml | 5 + ...qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml | 6 + ...wen3_0dot6b-32k-wint8-h800-tp1-static.yaml | 6 + .../qwen3_30b-32k-bf16-h800-tp1-static.yaml | 5 + .../qwen3_30b-32k-wint4-h800-tp1-static.yaml | 6 + .../yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml | 5 + .../yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml | 5 + .../yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml | 5 + .../yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml | 6 + .../yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml | 6 + .../yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml | 6 + 
.../yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml | 6 + .../yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml | 6 + .../yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml | 5 + .../yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml | 5 + .../yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml | 6 + .../yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml | 6 + benchmarks/yaml/request_yaml/eb45-128k.yaml | 8 + benchmarks/yaml/request_yaml/eb45-32k.yaml | 8 + benchmarks/yaml/request_yaml/qwen2-32k.yaml | 8 + benchmarks/yaml/request_yaml/qwen3-32k.yaml | 8 + benchmarks/yaml/request_yaml/x1-32k.yaml | 8 + benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml | 6 + benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml | 6 + benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml | 6 + .../x1-32k-wint4-prefixcache-h800-tp8.yaml | 10 + benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml | 6 + benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml | 6 + benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml | 6 + .../x1-32k-wint8-prefixcache-h800-tp8.yaml | 10 + build.sh | 66 +- custom_ops/0001-DeepGEMM-95e81b3.patch | 643 ++++++ custom_ops/cpu_ops/avx_weight_only.cc | 188 -- custom_ops/cpu_ops/rebuild_padding.cc | 268 +++ custom_ops/cpu_ops/xft_all_layer.cc | 201 -- custom_ops/cpu_ops/xft_greedy_search.cc | 126 -- custom_ops/gpu_ops/air_topp_sampling.cu | 1612 -------------- .../get_block_shape_and_split_kv_block.cu | 321 ++- custom_ops/gpu_ops/append_attn/utils.cuh | 9 + .../{cpp_extensions.cu => cpp_extensions.cc} | 324 ++- .../arch/memory_copy_sm80.h | 250 +++ .../broadcast_load_epilogue_array_c3x.hpp | 460 ++++ .../epilogue/broadcast_load_epilogue_c2x.hpp | 500 +++++ .../epilogue/broadcast_load_epilogue_c3x.hpp | 450 ++++ .../epilogue/scaled_mm_epilogues_c2x.hpp | 327 +++ .../epilogue/scaled_mm_epilogues_c3x.hpp | 453 ++++ .../builders/sm90_gmma_builder_gated.inl | 284 +++ .../collective/collective_builder_gated.hpp | 60 + .../gemm/collective/collective_mma_gated.hpp | 62 + ..._mma_gated_tma_gmma_ss_warpspecialized.hpp | 713 ++++++ ..._gated_tma_gmma_ss_warpspecialized_fp8.hpp | 724 +++++++ .../gemm/kernel/gemm_universal_gated.hpp | 71 + .../gemm/kernel/mixed_gemm_B_layout.h | 9 + ..._gated_tma_warpspecialized_cooperative.hpp | 705 ++++++ ...emm_gated_tma_warpspecialized_pingpong.hpp | 680 ++++++ .../gemm/threadblock/default_mma.h | 131 +- .../gemm/threadblock/default_mma_bf16.h | 146 +- .../gemm/threadblock/wint2x_mma_base.h | 237 ++ .../gemm/threadblock/wint2x_mma_multistage.h | 807 +++++++ .../gemm/threadblock/wint2x_tile_dequanter.h | 130 ++ .../gemm/threadblock/wint2x_unzip.h | 447 ++++ .../cutlass_extensions/wint_type_traits.h | 140 ++ .../gpu_ops/cutlass_kernels/cutlass_helper.h | 183 +- .../fp8_fp8_dual_gemm_scale_bias_act.h | 6 +- .../fp8_fp8_gemm_scale_bias_act.h | 5 +- .../fuse_dual_gemm_act_template_3x.h | 173 ++ .../fuse_gemm_act_template_3x.h | 151 ++ .../moe_gemm/fused_moe_cutlass_kernel.h | 422 +++- .../moe_gemm/fused_moe_gemm_kernels.h | 12 +- .../fused_moe_gemm_kernels_bf16_bf16.cu | 7 +- .../fused_moe_gemm_kernels_bf16_int2.cu | 30 + .../fused_moe_gemm_kernels_bf16_int4.cu | 6 +- .../fused_moe_gemm_kernels_bf16_int8.cu | 7 +- .../fused_moe_gemm_kernels_fp16_fp16.cu | 5 +- .../fused_moe_gemm_kernels_fp16_int2.cu | 27 + .../fused_moe_gemm_kernels_fp16_int4.cu | 5 +- .../fused_moe_gemm_kernels_fp16_int8.cu | 5 +- .../fused_moe_gemm_kernels_template.h | 302 +-- .../w4a8_moe/w4a8_moe_gemm_test.cu | 879 ++++---- .../w8a8/c3x/cutlass_gemm_caller.cuh | 102 + .../cutlass_kernels/w8a8/c3x/scaled_mm.cuh | 149 ++ .../w8a8/c3x/scaled_mm_azp_sm90_int8.cu | 27 + .../w8a8/c3x/scaled_mm_helper.hpp | 34 + 
.../w8a8/c3x/scaled_mm_kernels.hpp | 35 + .../w8a8/c3x/scaled_mm_sm90_fp8.cu | 28 + .../w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh | 125 ++ .../w8a8/c3x/scaled_mm_sm90_int8.cu | 29 + .../w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh | 168 ++ .../cutlass_kernels/w8a8/scaled_mm_c2x.cu | 200 ++ .../cutlass_kernels/w8a8/scaled_mm_c2x.cuh | 223 ++ .../w8a8/scaled_mm_c2x_sm75_dispatch.cuh | 125 ++ .../w8a8/scaled_mm_c2x_sm80_dispatch.cuh | 141 ++ .../w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh | 370 ++++ .../w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh | 355 +++ .../w8a8/scaled_mm_c3x_sm90.cu | 37 + .../cutlass_kernels/w8a8/scaled_mm_entry.cu | 224 ++ custom_ops/gpu_ops/fp8_deep_gemm/README.md | 27 - .../fp8_deep_gemm/deep_gemm/__init__.py | 31 - .../deep_gemm/include/deep_gemm/fp8_gemm.cuh | 462 ---- .../deep_gemm/include/deep_gemm/mma_utils.cuh | 903 -------- .../deep_gemm/include/deep_gemm/scheduler.cuh | 121 -- .../deep_gemm/include/deep_gemm/tma_utils.cuh | 116 - .../deep_gemm/include/deep_gemm/utils.cuh | 66 - .../fp8_deep_gemm/deep_gemm/jit/compiler.py | 208 -- .../deep_gemm/jit/interleave_ffma.py | 173 -- .../fp8_deep_gemm/deep_gemm/jit/runtime.py | 100 - .../fp8_deep_gemm/deep_gemm/jit/template.py | 150 -- .../deep_gemm/jit_kernels/gemm.py | 266 --- .../deep_gemm/jit_kernels/m_grouped_gemm.py | 329 --- .../deep_gemm/jit_kernels/tuner.py | 181 -- .../deep_gemm/jit_kernels/utils.py | 151 -- .../gpu_ops/fp8_deep_gemm/deep_gemm/utils.py | 137 -- custom_ops/gpu_ops/fp8_deep_gemm/setup.py | 110 - .../gpu_ops/fp8_deep_gemm/tests/test_core.py | 205 -- .../fp8_fp8_half_block_gemm.cu | 178 +- custom_ops/gpu_ops/get_data_ptr_ipc.cu | 4 +- custom_ops/gpu_ops/get_mm_split_fuse.cc | 2 + custom_ops/gpu_ops/get_output_ep.cc | 11 +- custom_ops/gpu_ops/helper.h | 618 +++--- .../ipc_sent_key_value_cache_by_remote_ptr.cu | 358 +-- custom_ops/gpu_ops/moe/deepgemm_preprocess.cu | 61 + custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu | 167 +- custom_ops/gpu_ops/moe/fused_moe.cu | 7 +- custom_ops/gpu_ops/moe/fused_moe_helper.h | 40 +- custom_ops/gpu_ops/moe/fused_moe_op.h | 575 +++-- custom_ops/gpu_ops/moe/gptq_marlin_repack.cu | 368 ++++ custom_ops/gpu_ops/moe/moe_dispatch.cu | 11 +- custom_ops/gpu_ops/moe/moe_ffn.cu | 32 +- custom_ops/gpu_ops/moe/moe_ffn_wint2.cu | 377 ++++ custom_ops/gpu_ops/moe/moe_reduce.cu | 5 +- custom_ops/gpu_ops/moe/moe_topk_select.cu | 1 + .../gpu_ops/moe/moe_wna16_marlin_gemm.cu | 1122 ++++++++++ .../gpu_ops/moe/moe_wna16_marlin_gemm.h | 37 + .../moe/moe_wna16_marlin_utils/CUDAStream.h | 63 + .../moe/moe_wna16_marlin_utils/ScalarType.h | 372 ++++ .../moe/moe_wna16_marlin_utils/dequant.h | 508 +++++ .../moe/moe_wna16_marlin_utils/kernel.h | 42 + .../moe_wna16_marlin_utils/kernel_bf16_ku4.cu | 89 + .../kernel_bf16_ku4b8.cu | 89 + .../moe_wna16_marlin_utils/kernel_fp16_ku4.cu | 89 + .../kernel_fp16_ku4b8.cu | 109 + .../moe/moe_wna16_marlin_utils/marlin.cuh | 99 + .../moe_wna16_marlin_utils/marlin_dtypes.cuh | 83 + .../moe_wna16_marlin_utils/marlin_template.h | 1927 +++++++++++++++++ .../moe/moe_wna16_marlin_utils/types.h | 85 + .../gpu_ops/moe/tritonmoe_preprocess.cu | 198 ++ custom_ops/gpu_ops/moe/wintx_unzip.cu | 316 +++ .../gpu_ops/open_shm_and_get_meta_signal.cc | 14 +- custom_ops/gpu_ops/quantization/common.cu | 235 ++ custom_ops/gpu_ops/quantization/common.cuh | 159 ++ custom_ops/gpu_ops/remote_cache_kv_ipc.cc | 13 +- custom_ops/gpu_ops/remote_cache_kv_ipc.h | 1 + custom_ops/gpu_ops/reset_need_stop_value.cc | 30 - .../sample_kernels/air_top_p_sampling.cu | 1469 +++++++++++++ 
.../rejection_top_p_sampling.cu | 73 + .../gpu_ops/sample_kernels/sampling.cuh | 559 +++++ custom_ops/gpu_ops/sample_kernels/utils.cuh | 269 +++ custom_ops/gpu_ops/save_with_output_msg.cc | 4 + .../draft_model/draft_model_preprocess.cu | 65 +- .../draft_model/mtp_save_first_token.cc | 4 +- .../gpu_ops/speculate_decoding/ngram_match.cc | 46 +- .../speculate_get_output.cc | 2 +- .../speculate_save_output.cc | 2 +- .../speculate_decoding/speculate_verify.cu | 712 +++--- custom_ops/gpu_ops/step_reschedule.cu | 6 +- .../gpu_ops/text_image_gather_scatter.cu | 233 ++ custom_ops/gpu_ops/text_image_index_out.cu | 64 + custom_ops/setup_ops.py | 303 +-- custom_ops/setup_ops_base.py | 8 +- custom_ops/setup_ops_cpu.py | 105 +- ...n_fp8_fp8_block_gemm_fused_kernels_sm90.py | 134 +- ...uto_gen_fp8_fp8_dual_gemm_fused_kernels.py | 109 +- ...en_fp8_fp8_dual_gemm_fused_kernels_sm90.py | 592 +++++ .../auto_gen_fp8_fp8_gemm_fused_kernels.py | 160 +- ...uto_gen_fp8_fp8_gemm_fused_kernels_sm90.py | 614 ++++++ ...auto_gen_visitor_fp8_gemm_fused_kernels.py | 64 +- custom_ops/xpu_ops/src/build.sh | 33 + custom_ops/xpu_ops/src/ops/adjust_batch.cc | 166 ++ custom_ops/xpu_ops/src/ops/block_attn.cc | 320 +++ .../device/get_context_gm_max_mem_demand.cc | 58 + .../src/ops/device/get_free_global_memory.cc | 61 + .../src/ops/device/get_total_global_memory.cc | 60 + .../src/ops/device/get_used_global_memory.cc | 60 + .../xpu_ops/src/ops/gather_next_token.cc | 108 + custom_ops/xpu_ops/src/ops/get_infer_param.cc | 244 +++ custom_ops/xpu_ops/src/ops/get_output.cc | 120 + .../xpu_ops/src/ops/get_padding_offset.cc | 94 + .../src/ops/get_token_penalty_multi_scores.cc | 81 + custom_ops/xpu_ops/src/ops/moe_layer.cc | 274 +++ .../xpu_ops/src/ops/save_with_output_msg.cc | 137 ++ .../src/ops/set_value_by_flags_and_idx.cc | 48 + custom_ops/xpu_ops/src/ops/step.cc | 142 ++ .../src/ops/stop_generation_multi_ends.cc | 59 + custom_ops/xpu_ops/src/ops/update_inputs.cc | 85 + custom_ops/xpu_ops/src/ops/utility/helper.h | 70 + .../xpu_ops/src/ops/weight_quantize_xpu.cc | 126 ++ custom_ops/xpu_ops/src/ops/xpu_multiprocess.h | 83 + custom_ops/xpu_ops/src/plugin/CMakeLists.txt | 408 ++++ custom_ops/xpu_ops/src/plugin/README.md | 26 + .../xpu_ops/src/plugin/build.sh | 29 +- .../xpu_ops/src/plugin/include/xpu/plugin.h | 112 + .../src/kernel/kunlun3cpp/ban_bad_words.xpu | 58 + .../src/kernel/kunlun3cpp/eb_adjust_batch.xpu | 129 ++ .../kunlun3cpp/eb_gather_next_token.xpu | 102 + .../kunlun3cpp/free_and_dispatch_block.xpu | 326 +++ .../kernel/kunlun3cpp/get_padding_offset.xpu | 53 + .../kunlun3cpp/min_length_logits_process.xpu | 68 + .../kernel/kunlun3cpp/quant2d_per_channel.xpu | 1069 +++++++++ .../src/kernel/kunlun3cpp/recover_block.xpu | 154 ++ .../src/kernel/kunlun3cpp/remove_padding.xpu | 40 + .../kunlun3cpp/set_stop_value_multi_ends.xpu | 102 + .../kunlun3cpp/set_value_by_flags_and_idx.xpu | 50 + .../src/kernel/kunlun3cpp/update_inputs.xpu | 77 + .../kernel/kunlun3cpp/update_repeat_times.xpu | 75 + .../update_value_by_repeat_times.xpu | 211 ++ .../xpu_ops/src/plugin/src/linker.specs | 6 + .../plugin/src/wrapper/eb_adjust_batch.cpp | 169 ++ .../src/wrapper/eb_gather_next_token.cpp | 139 ++ .../src/wrapper/free_and_dispatch_block.cpp | 222 ++ .../plugin/src/wrapper/get_padding_offset.cpp | 131 ++ .../wrapper/nn_set_stop_value_multi_ends.cpp | 139 ++ .../wrapper/nn_set_value_by_flags_and_idx.cpp | 139 ++ .../wrapper/nn_token_penalty_multi_scores.cpp | 274 +++ .../src/wrapper/quant2d_per_channel.cpp | 280 +++ 
.../src/plugin/src/wrapper/recover_block.cpp | 165 ++ .../src/plugin/src/wrapper/update_inputs.cpp | 119 + custom_ops/xpu_ops/src/setup_ops.py | 199 ++ .../python/ops/test_get_padding_offset.py | 64 + .../test_get_token_penalty_multi_scores.py | 254 +++ .../ops/test_set_value_by_flags_and_idx.py | 76 + .../xpu_ops/test/python/ops/test_step.py | 170 ++ .../ops/test_stop_generation_multi_ends.py | 139 ++ .../ops/test_token_repetition_penalty.py | 45 + .../test/python/ops/test_update_inputs.py | 106 + .../python/ops/test_weight_quantize_xpu.py | 94 + dockerfiles/Dockerfile.gpu | 28 + dockerfiles/Dockerfile.xpu | 43 + docs/benchmark.md | 40 + docs/code_guide.md | 22 - docs/features/chunked_prefill.md | 25 + docs/features/disaggregated.md | 166 ++ docs/features/images/GlobalScheduler.png | Bin 0 -> 454300 bytes docs/features/images/LocalScheduler.png | Bin 0 -> 336591 bytes docs/features/images/disaggregated.png | Bin 0 -> 308510 bytes docs/features/load_balance.md | 81 + docs/features/prefix_caching.md | 39 + docs/features/reasoning_output.md | 67 + docs/features/speculative_decoding.md | 150 ++ docs/features/structured_outputs.md | 332 +++ docs/get_started/ernie-4.5-vl.md | 199 ++ docs/get_started/ernie-4.5.md | 89 + docs/get_started/installation/Enflame_gcu.md | 129 ++ docs/get_started/installation/README.md | 9 +- docs/get_started/installation/iluvatar_gpu.md | 101 + .../get_started/installation/kunlunxin_xpu.md | 221 ++ docs/get_started/installation/nvidia_gpu.md | 89 + docs/get_started/quick_start.md | 93 + docs/get_started/quick_start_vl.md | 106 + docs/index.md | 37 + docs/metrics.md | 20 - docs/offline_inference.md | 150 +- docs/online_serving/README.md | 97 + docs/online_serving/metrics.md | 27 + docs/online_serving/scheduler.md | 39 + docs/parameters.md | 229 +- docs/quantization/README.md | 46 + docs/quantization/online_quantization.md | 54 + docs/quantization/wint2.md | 59 + docs/requirements.txt | 5 + docs/serving.md | 140 -- docs/supported_models.md | 35 + docs/usage/code_overview.md | 25 + docs/usage/environment_variables.md | 72 + docs/usage/log.md | 38 + docs/zh/benchmark.md | 40 + docs/zh/features/chunked_prefill.md | 25 + docs/zh/features/disaggregated.md | 173 ++ docs/zh/features/images/GlobalScheduler.png | Bin 0 -> 454300 bytes docs/zh/features/images/LocalScheduler.png | Bin 0 -> 336591 bytes docs/zh/features/images/disaggregated.png | Bin 0 -> 308510 bytes docs/zh/features/load_balance.md | 69 + docs/zh/features/prefix_caching.md | 40 + docs/zh/features/reasoning_output.md | 76 + docs/zh/features/speculative_decoding.md | 120 + docs/zh/features/structured_outputs.md | 332 +++ docs/zh/get_started/ernie-4.5-vl.md | 201 ++ docs/zh/get_started/ernie-4.5.md | 85 + .../get_started/installation/Enflame_gcu.md | 128 ++ docs/zh/get_started/installation/README.md | 8 + .../get_started/installation/iluvatar_gpu.md | 102 + .../get_started/installation/kunlunxin_xpu.md | 226 ++ .../zh/get_started/installation/nvidia_gpu.md | 87 + docs/zh/get_started/quick_start.md | 85 + docs/zh/get_started/quick_start_vl.md | 103 + docs/zh/index.md | 35 + docs/zh/offline_inference.md | 133 ++ docs/zh/online_serving/README.md | 97 + docs/zh/online_serving/metrics.md | 27 + docs/zh/online_serving/scheduler.md | 41 + docs/zh/parameters.md | 117 + docs/zh/quantization/README.md | 47 + docs/zh/quantization/online_quantization.md | 57 + docs/zh/quantization/wint2.md | 62 + docs/zh/supported_models.md | 36 + docs/zh/usage/code_overview.md | 25 + docs/zh/usage/environment_variables.md | 70 + 
docs/zh/usage/log.md | 44 + fastdeploy/__init__.py | 13 +- .../__init__.py} | 2 +- fastdeploy/cache_manager/cache_data.py | 162 ++ fastdeploy/cache_manager/cache_messager.py | 318 +++ fastdeploy/cache_manager/cache_metrics.py | 137 ++ .../cache_manager/cache_transfer_manager.py | 470 ++++ .../cache_manager/prefix_cache_manager.py | 1033 +++++++++ .../transfer_factory}/__init__.py | 12 +- .../transfer_factory/ipc_cache_transfer.py | 133 ++ .../kvcache_transfer/CMakeLists.txt | 35 + .../kvcache_transfer/README.md | 232 ++ .../kvcache_transfer/README_CN.md | 232 ++ .../include/kvcache_connection.h | 211 ++ .../kvcache_transfer/include/kvcache_rdma.h | 127 ++ .../kvcache_transfer/include/log.h | 117 + .../kvcache_transfer/include/util.h | 315 +++ .../src/kvcache_connection.cpp | 1050 +++++++++ .../kvcache_transfer/src/kvcache_rdma.cpp | 1056 +++++++++ .../kvcache_transfer/src/log.cpp | 212 ++ .../kvcache_transfer/src/pybind.cpp | 22 + .../transfer_factory/rdma_cache_transfer.py | 76 + fastdeploy/config.py | 392 ++-- fastdeploy/demo/offline_demo.py | 29 + fastdeploy/demo/offline_disaggregated_demo.py | 63 + .../demo/offline_prefix_caching_demo.py | 56 + fastdeploy/demo/openai_demo.py | 82 + fastdeploy/demo/openai_vl_demo.py | 98 + fastdeploy/distributed/communication_op.py | 10 +- fastdeploy/download_model.py | 227 ++ fastdeploy/engine/args_utils.py | 443 +++- fastdeploy/engine/config.py | 686 ++++-- fastdeploy/engine/engine.py | 763 +++++-- fastdeploy/engine/expert_service.py | 370 ++++ fastdeploy/engine/request.py | 262 ++- fastdeploy/engine/resource_manager.py | 372 +++- fastdeploy/engine/sampling_params.py | 184 +- fastdeploy/entrypoints/api_server.py | 49 +- fastdeploy/entrypoints/chat_utils.py | 117 +- fastdeploy/entrypoints/engine_client.py | 31 +- fastdeploy/entrypoints/llm.py | 262 ++- fastdeploy/entrypoints/openai/api_server.py | 303 +-- fastdeploy/entrypoints/openai/protocol.py | 394 ++-- fastdeploy/entrypoints/openai/serving_chat.py | 171 +- .../entrypoints/openai/serving_completion.py | 175 +- fastdeploy/entrypoints/openai/test_openai.py | 82 + fastdeploy/envs.py | 103 + fastdeploy/import_ops.py | 32 +- fastdeploy/inference_args.py | 628 ------ fastdeploy/input/ernie_processor.py | 444 ++++ fastdeploy/input/ernie_tokenizer.py | 455 ++-- fastdeploy/input/ernie_vl_processor.py | 260 +++ .../image_preprocessor_adaptive.py | 12 +- fastdeploy/input/mm_processor/process.py | 197 +- .../mm_processor/tokenizer/tokenizer_vl.py | 69 +- fastdeploy/input/multimodal/image.py | 4 +- fastdeploy/input/multimodal/video.py | 144 +- fastdeploy/input/preprocess.py | 44 +- fastdeploy/input/text_processor.py | 235 +- fastdeploy/inter_communicator.py | 546 ----- fastdeploy/inter_communicator/__init__.py | 25 + .../inter_communicator/engine_cache_queue.py | 310 +++ .../inter_communicator/engine_worker_queue.py | 416 ++++ fastdeploy/inter_communicator/ipc_signal.py | 96 + fastdeploy/inter_communicator/zmq_client.py | 196 ++ fastdeploy/metrics/__init__.py | 31 +- fastdeploy/metrics/metrics.py | 62 +- fastdeploy/metrics/work_metrics.py | 18 +- .../cudagraph_piecewise_backend.py | 57 +- .../graph_optimization/decorator.py | 20 +- .../graph_optimization_backend.py | 51 +- .../guided_decoding/__init__.py | 73 + .../guided_decoding/base_guided_decoding.py | 347 +++ .../guided_decoding/ernie_tokenizer.py | 266 +++ .../guided_decoding/xgrammar_backend.py | 457 ++++ .../model_executor/layers/activation.py | 43 +- .../layers/attention/__init__.py | 12 +- .../layers/attention/append_attn_backend.py | 126 +- 
.../layers/attention/attention.py | 87 +- .../layers/attention/attention_selecter.py | 24 +- .../model_executor/layers/attention/base.py | 395 ---- .../attention/base_attention_backend.py | 16 +- .../layers/attention/native_paddle_backend.py | 103 +- .../layers/attention/ops/__init__.py | 11 +- .../layers/attention/ops/append_attention.py | 16 +- .../attention/ops/init_signal_layerwise.py | 37 +- .../ops/open_shm_and_get_meta_signal.py | 35 + .../layers/attention/xpu_attn_backend.py | 188 ++ .../layers/backends/xpu/__init__.py | 4 +- .../backends/xpu/quantization}/__init__.py | 9 +- .../backends/xpu/quantization/weight_only.py | 146 +- .../layers/backends/xpu/utils.py | 10 +- .../model_executor/layers/embeddings.py | 67 +- .../model_executor/layers/hydra_head.py | 2 +- fastdeploy/model_executor/layers/linear.py | 685 +++--- fastdeploy/model_executor/layers/lm_head.py | 158 +- .../model_executor/layers/moe/__init__.py | 10 + .../layers/moe/cutlass_fused_moe.py | 222 -- fastdeploy/model_executor/layers/moe/ep.py | 1183 ++-------- .../layers/moe/fused_moe_backend_base.py | 135 ++ .../layers/moe/fused_moe_cutlass_backend.py | 431 ++++ .../layers/moe/fused_moe_deepgemm_backend.py | 380 ++++ .../layers/moe/fused_moe_marlin_backend.py | 285 +++ .../layers/moe/fused_moe_method_base.py | 57 - .../layers/moe/fused_moe_triton_backend.py | 479 ++++ .../layers/moe/fused_moe_wint2_backend.py | 236 ++ fastdeploy/model_executor/layers/moe/mm.py | 273 --- fastdeploy/model_executor/layers/moe/moe.py | 316 ++- fastdeploy/model_executor/layers/moe/tp.py | 126 -- .../layers/moe/triton_moe_kernels.py | 198 ++ .../model_executor/layers/normalization.py | 37 +- .../layers/quantization/__init__.py | 31 +- .../{block_wise.py => block_wise_fp8.py} | 66 +- .../layers/quantization/kv_cache.py | 263 +-- .../layers/quantization/mix_quant.py | 75 + .../layers/quantization/ops}/__init__.py | 15 +- .../quantization/ops/cutlass_scaled_mm.py | 126 ++ .../quantization/ops/scaled_fp8_quant.py | 75 + .../layers/quantization/quant_base.py | 5 +- .../layers/quantization/tensor_wise_fp8.py | 135 ++ .../layers/quantization/w4a8.py | 42 + .../layers/quantization/w4afp8.py | 21 +- .../layers/quantization/w8a8.py | 78 +- .../layers/quantization/weight_only.py | 103 +- .../layers/quantization/wfp8afp8.py | 96 +- .../layers/quantization/wint2.py | 142 ++ .../model_executor/layers/rotary_embedding.py | 90 +- .../model_executor/layers/sample/meta_data.py | 3 +- .../layers/sample/ops/__init__.py | 6 +- .../sample/ops/apply_penalty_multi_scores.py | 69 +- .../layers/sample/ops/top_p_sampling.py | 97 + .../model_executor/layers/sample/sampler.py | 320 ++- fastdeploy/model_executor/layers/utils.py | 251 ++- fastdeploy/model_executor/model_loader.py | 51 +- fastdeploy/model_executor/models/__init__.py | 50 +- .../model_executor/models/ernie4_5_moe.py | 774 +++++++ .../model_executor/models/ernie4_5_mtp.py | 417 ++++ .../models/ernie4_5_vl}/__init__.py | 0 .../models/ernie4_5_vl/configuration.py | 167 ++ .../models/ernie4_5_vl/dfnrope/__init__.py | 22 + .../models/ernie4_5_vl/dfnrope/activation.py | 287 +++ .../ernie4_5_vl/dfnrope/configuration.py | 70 + .../models/ernie4_5_vl/dfnrope/modeling.py | 732 +++++++ .../models/ernie4_5_vl/dist_utils.py | 130 ++ .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 511 +++++ .../models/ernie4_5_vl/modeling_resampler.py | 399 ++++ .../model_executor/models/export_model.py | 652 ------ fastdeploy/model_executor/models/qwen2.py | 108 +- fastdeploy/model_executor/models/qwen3.py | 361 +++ 
fastdeploy/model_executor/models/qwen3moe.py | 509 +++++ fastdeploy/model_executor/models/tokenizer.py | 382 ---- fastdeploy/model_executor/models/utils.py | 972 ++------- .../model_executor/ops/triton_ops/__init__.py | 22 + .../ops/triton_ops/triton_utils.py | 804 +++++++ .../ops/triton_ops/wint2_fused_moe.py | 549 +++++ .../model_executor/pre_and_post_process.py | 343 ++- fastdeploy/output/token_processor.py | 331 ++- fastdeploy/platforms/cuda.py | 25 +- fastdeploy/platforms/xpu.py | 40 +- fastdeploy/reasoning/__init__.py | 25 + fastdeploy/reasoning/abs_reasoning_parsers.py | 188 ++ .../reasoning/ernie_vl_reasoning_parsers.py | 106 + .../reasoning/qwen3_reasoning_parsers.py | 145 ++ fastdeploy/scheduler/config.py | 185 +- fastdeploy/scheduler/data.py | 179 +- fastdeploy/scheduler/global_scheduler.py | 843 +++++-- fastdeploy/scheduler/local_scheduler.py | 270 ++- fastdeploy/scheduler/splitwise_scheduler.py | 835 +++++++ fastdeploy/scheduler/storage.py | 161 +- fastdeploy/scheduler/utils.py | 35 + fastdeploy/scheduler/workers.py | 204 +- fastdeploy/spec_decode/__init__.py | 22 + fastdeploy/spec_decode/base.py | 63 + fastdeploy/spec_decode/mtp.py | 629 ++++++ fastdeploy/spec_decode/ngram.py | 69 + .../xpu_worker.py => splitwise/__init__.py} | 2 +- fastdeploy/splitwise/splitwise_connector.py | 481 ++++ fastdeploy/start_splitwise.sh | 15 + fastdeploy/stop.sh | 21 + fastdeploy/test.yaml | 4 + fastdeploy/utils.py | 349 ++- fastdeploy/worker/V1/gpu_model_runner.py | 545 ----- fastdeploy/worker/V1/worker_process.py | 318 --- .../{model_executor/eplb => worker}/eplb.py | 116 +- .../eplb => worker}/experts_manager.py | 107 +- .../worker/{model_runner => }/forward_meta.py | 207 +- fastdeploy/worker/gpu_model_runner.py | 1203 ++++++++++ fastdeploy/worker/{V1 => }/gpu_worker.py | 124 +- .../model_runner/model_runner_inference.py | 509 ----- .../model_runner/model_runner_minimal_os.py | 74 - .../model_runner/model_runner_paddlenlp.py | 318 --- .../worker/{V1 => }/model_runner_base.py | 41 +- fastdeploy/worker/output.py | 126 +- fastdeploy/worker/utils.py | 240 +- fastdeploy/worker/vl_gpu_model_runner.py | 1204 ++++++++++ ...runner_base.py => vl_model_runner_base.py} | 197 +- .../{worker.py => vl_worker_process.py} | 310 ++- fastdeploy/worker/{V1 => }/worker_base.py | 36 +- fastdeploy/worker/worker_process.py | 772 +++++++ fastdeploy/worker/xpu_model_runner.py | 819 +++++++ fastdeploy/worker/xpu_worker.py | 165 ++ mkdocs.yml | 47 + requirements.txt | 10 +- scripts/build_wheel_pipeline_cu123.sh | 40 +- scripts/codestyle/clang-tidy.py | 497 ----- scripts/codestyle/clang_format.sh | 35 - scripts/codestyle/copyright.py | 132 -- scripts/codestyle/pre_commit.sh | 68 - scripts/convert_ep_to_safetensor.py | 252 +++ scripts/extract_mtp_weight_from_safetensor.py | 122 ++ scripts/prefill_fake_server.sh | 78 - scripts/run_ci.sh | 37 - scripts/run_offline_quantization.sh | 38 - .../run_prediction_ep_decoder_multi_node.sh | 22 + ...n_prediction_ep_decoder_multi_node_perf.sh | 28 + ..._prediction_ep_decoder_single_node_perf.sh | 22 + ...t.sh => run_prediction_ep_prefill_perf.sh} | 14 +- scripts/run_unittest.sh | 4 +- scripts/vit_model_split.py | 67 + scripts/vit_model_split.sh | 15 + setup.py | 184 +- test/ci_use/test_qwen2_offline.py | 167 -- test/ci_use/test_qwen2_serving.py | 491 ----- test/layers/test_append_attention.py | 624 ++++++ test/layers/test_attention.py | 174 +- test/layers/test_sampler.py | 8 +- test/operators/test_air_topp_sampling.py | 35 +- test/operators/test_cutlass_scaled_mm.py | 101 + 
.../test_deqant_int8_cpp_extension.py | 22 +- .../test_rejection_top_p_sampling.py | 66 + test/worker/test_cuda_graph.py | 132 +- tools/dockerfile/Dockerfile.ci | 5 - 597 files changed, 78776 insertions(+), 22905 deletions(-) create mode 100644 .clang-format create mode 100644 benchmarks/README.md create mode 100644 benchmarks/backend_request_func.py create mode 100644 benchmarks/benchmark_dataset.py create mode 100644 benchmarks/benchmark_serving.py create mode 100644 benchmarks/benchmark_utils.py create mode 100644 benchmarks/requirements.txt create mode 100644 benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml create mode 100644 benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml create mode 100644 benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml create mode 100644 benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml create mode 100644 benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml create mode 100644 benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml create mode 100644 benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml create mode 100644 benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml create mode 100644 benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml create mode 100644 benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml create mode 100644 benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml create mode 100644 benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml create mode 100644 benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml create mode 100644 benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml create mode 100644 benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml create mode 100644 
benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml create mode 100644 benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml create mode 100644 benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml create mode 100644 benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml create mode 100644 benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml create mode 100644 benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml create mode 100644 benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml create mode 100644 benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml create mode 100644 benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml create mode 100644 benchmarks/yaml/request_yaml/eb45-128k.yaml create mode 100644 benchmarks/yaml/request_yaml/eb45-32k.yaml create mode 100644 benchmarks/yaml/request_yaml/qwen2-32k.yaml create mode 100644 benchmarks/yaml/request_yaml/qwen3-32k.yaml create mode 100644 benchmarks/yaml/request_yaml/x1-32k.yaml create mode 100644 benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml create mode 100644 benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml create mode 100644 benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml create mode 100644 benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml create mode 100644 benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml create mode 100644 benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml create mode 100644 benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml create mode 100644 benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml create mode 100644 custom_ops/0001-DeepGEMM-95e81b3.patch delete mode 100644 custom_ops/cpu_ops/avx_weight_only.cc create mode 100644 custom_ops/cpu_ops/rebuild_padding.cc delete mode 100644 custom_ops/cpu_ops/xft_all_layer.cc delete mode 100644 custom_ops/cpu_ops/xft_greedy_search.cc delete mode 100644 custom_ops/gpu_ops/air_topp_sampling.cu rename custom_ops/gpu_ops/{cpp_extensions.cu => cpp_extensions.cc} (61%) create mode 100644 custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h create mode 100644 custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp create mode 100644 
custom_ops/gpu_ops/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h create mode 100644 custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h create mode 100644 custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h create mode 100644 custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h create mode 100644 custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h create mode 100644 custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/cutlass_gemm_caller.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x_sm75_dispatch.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x_sm80_dispatch.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu create mode 100644 custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_entry.cu delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/README.md delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/__init__.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/deep_gemm/fp8_gemm.cuh delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/deep_gemm/mma_utils.cuh delete mode 100644 
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/deep_gemm/scheduler.cuh delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/deep_gemm/tma_utils.cuh delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/deep_gemm/utils.cuh delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit/compiler.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit/interleave_ffma.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit/runtime.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit/template.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit_kernels/gemm.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit_kernels/m_grouped_gemm.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit_kernels/tuner.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit_kernels/utils.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/utils.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/setup.py delete mode 100644 custom_ops/gpu_ops/fp8_deep_gemm/tests/test_core.py create mode 100644 custom_ops/gpu_ops/moe/deepgemm_preprocess.cu create mode 100644 custom_ops/gpu_ops/moe/gptq_marlin_repack.cu create mode 100644 custom_ops/gpu_ops/moe/moe_ffn_wint2.cu create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/ScalarType.h create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/dequant.h create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel.h create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4.cu create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_bf16_ku4b8.cu create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4.cu create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_fp16_ku4b8.cu create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin.cuh create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_dtypes.cuh create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/marlin_template.h create mode 100644 custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/types.h create mode 100644 custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu create mode 100644 custom_ops/gpu_ops/moe/wintx_unzip.cu create mode 100644 custom_ops/gpu_ops/quantization/common.cu create mode 100644 custom_ops/gpu_ops/quantization/common.cuh delete mode 100644 custom_ops/gpu_ops/reset_need_stop_value.cc create mode 100644 custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu create mode 100644 custom_ops/gpu_ops/sample_kernels/rejection_top_p_sampling.cu create mode 100644 custom_ops/gpu_ops/sample_kernels/sampling.cuh create mode 100644 custom_ops/gpu_ops/sample_kernels/utils.cuh create mode 100644 custom_ops/gpu_ops/text_image_gather_scatter.cu create mode 100644 custom_ops/gpu_ops/text_image_index_out.cu rename custom_ops/{ => utils}/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py (85%) rename custom_ops/{ => utils}/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py (88%) create mode 100644 custom_ops/utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py rename custom_ops/{ => utils}/auto_gen_fp8_fp8_gemm_fused_kernels.py (82%) create mode 100644 custom_ops/utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py rename custom_ops/{ => 
utils}/auto_gen_visitor_fp8_gemm_fused_kernels.py (91%) create mode 100755 custom_ops/xpu_ops/src/build.sh create mode 100644 custom_ops/xpu_ops/src/ops/adjust_batch.cc create mode 100644 custom_ops/xpu_ops/src/ops/block_attn.cc create mode 100644 custom_ops/xpu_ops/src/ops/device/get_context_gm_max_mem_demand.cc create mode 100644 custom_ops/xpu_ops/src/ops/device/get_free_global_memory.cc create mode 100644 custom_ops/xpu_ops/src/ops/device/get_total_global_memory.cc create mode 100644 custom_ops/xpu_ops/src/ops/device/get_used_global_memory.cc create mode 100644 custom_ops/xpu_ops/src/ops/gather_next_token.cc create mode 100644 custom_ops/xpu_ops/src/ops/get_infer_param.cc create mode 100644 custom_ops/xpu_ops/src/ops/get_output.cc create mode 100644 custom_ops/xpu_ops/src/ops/get_padding_offset.cc create mode 100644 custom_ops/xpu_ops/src/ops/get_token_penalty_multi_scores.cc create mode 100644 custom_ops/xpu_ops/src/ops/moe_layer.cc create mode 100644 custom_ops/xpu_ops/src/ops/save_with_output_msg.cc create mode 100644 custom_ops/xpu_ops/src/ops/set_value_by_flags_and_idx.cc create mode 100644 custom_ops/xpu_ops/src/ops/step.cc create mode 100644 custom_ops/xpu_ops/src/ops/stop_generation_multi_ends.cc create mode 100644 custom_ops/xpu_ops/src/ops/update_inputs.cc create mode 100644 custom_ops/xpu_ops/src/ops/utility/helper.h create mode 100644 custom_ops/xpu_ops/src/ops/weight_quantize_xpu.cc create mode 100644 custom_ops/xpu_ops/src/ops/xpu_multiprocess.h create mode 100644 custom_ops/xpu_ops/src/plugin/CMakeLists.txt create mode 100644 custom_ops/xpu_ops/src/plugin/README.md rename scripts/codestyle/sort_txt_file.py => custom_ops/xpu_ops/src/plugin/build.sh (54%) mode change 100644 => 100755 create mode 100644 custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/ban_bad_words.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_gather_next_token.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/free_and_dispatch_block.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/min_length_logits_process.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/quant2d_per_channel.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_block.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/remove_padding.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/set_stop_value_multi_ends.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/set_value_by_flags_and_idx.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_repeat_times.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_value_by_repeat_times.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/linker.specs create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/eb_gather_next_token.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/free_and_dispatch_block.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp create mode 
100644 custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/quant2d_per_channel.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/recover_block.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs.cpp create mode 100755 custom_ops/xpu_ops/src/setup_ops.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_get_padding_offset.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_get_token_penalty_multi_scores.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_set_value_by_flags_and_idx.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_step.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_stop_generation_multi_ends.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_token_repetition_penalty.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_update_inputs.py create mode 100644 custom_ops/xpu_ops/test/python/ops/test_weight_quantize_xpu.py create mode 100644 dockerfiles/Dockerfile.gpu create mode 100644 dockerfiles/Dockerfile.xpu create mode 100644 docs/benchmark.md delete mode 100644 docs/code_guide.md create mode 100644 docs/features/chunked_prefill.md create mode 100644 docs/features/disaggregated.md create mode 100644 docs/features/images/GlobalScheduler.png create mode 100644 docs/features/images/LocalScheduler.png create mode 100644 docs/features/images/disaggregated.png create mode 100644 docs/features/load_balance.md create mode 100644 docs/features/prefix_caching.md create mode 100644 docs/features/reasoning_output.md create mode 100644 docs/features/speculative_decoding.md create mode 100644 docs/features/structured_outputs.md create mode 100644 docs/get_started/ernie-4.5-vl.md create mode 100644 docs/get_started/ernie-4.5.md create mode 100644 docs/get_started/installation/Enflame_gcu.md create mode 100644 docs/get_started/installation/iluvatar_gpu.md create mode 100644 docs/get_started/installation/kunlunxin_xpu.md create mode 100644 docs/get_started/installation/nvidia_gpu.md create mode 100644 docs/get_started/quick_start.md create mode 100644 docs/get_started/quick_start_vl.md create mode 100644 docs/index.md delete mode 100644 docs/metrics.md create mode 100644 docs/online_serving/README.md create mode 100644 docs/online_serving/metrics.md create mode 100644 docs/online_serving/scheduler.md create mode 100644 docs/quantization/README.md create mode 100644 docs/quantization/online_quantization.md create mode 100644 docs/quantization/wint2.md create mode 100644 docs/requirements.txt delete mode 100644 docs/serving.md create mode 100644 docs/supported_models.md create mode 100644 docs/usage/code_overview.md create mode 100644 docs/usage/environment_variables.md create mode 100644 docs/usage/log.md create mode 100644 docs/zh/benchmark.md create mode 100644 docs/zh/features/chunked_prefill.md create mode 100644 docs/zh/features/disaggregated.md create mode 100644 docs/zh/features/images/GlobalScheduler.png create mode 100644 docs/zh/features/images/LocalScheduler.png create mode 100644 docs/zh/features/images/disaggregated.png create mode 100644 docs/zh/features/load_balance.md create mode 100644 docs/zh/features/prefix_caching.md create mode 100644 docs/zh/features/reasoning_output.md create mode 100644 
docs/zh/features/speculative_decoding.md create mode 100644 docs/zh/features/structured_outputs.md create mode 100644 docs/zh/get_started/ernie-4.5-vl.md create mode 100644 docs/zh/get_started/ernie-4.5.md create mode 100644 docs/zh/get_started/installation/Enflame_gcu.md create mode 100644 docs/zh/get_started/installation/README.md create mode 100644 docs/zh/get_started/installation/iluvatar_gpu.md create mode 100644 docs/zh/get_started/installation/kunlunxin_xpu.md create mode 100644 docs/zh/get_started/installation/nvidia_gpu.md create mode 100644 docs/zh/get_started/quick_start.md create mode 100644 docs/zh/get_started/quick_start_vl.md create mode 100644 docs/zh/index.md create mode 100644 docs/zh/offline_inference.md create mode 100644 docs/zh/online_serving/README.md create mode 100644 docs/zh/online_serving/metrics.md create mode 100644 docs/zh/online_serving/scheduler.md create mode 100644 docs/zh/parameters.md create mode 100644 docs/zh/quantization/README.md create mode 100644 docs/zh/quantization/online_quantization.md create mode 100644 docs/zh/quantization/wint2.md create mode 100644 docs/zh/supported_models.md create mode 100644 docs/zh/usage/code_overview.md create mode 100644 docs/zh/usage/environment_variables.md create mode 100644 docs/zh/usage/log.md rename fastdeploy/{worker/V1/xpu_model_runner.py => cache_manager/__init__.py} (99%) create mode 100644 fastdeploy/cache_manager/cache_data.py create mode 100644 fastdeploy/cache_manager/cache_messager.py create mode 100644 fastdeploy/cache_manager/cache_metrics.py create mode 100644 fastdeploy/cache_manager/cache_transfer_manager.py create mode 100644 fastdeploy/cache_manager/prefix_cache_manager.py rename fastdeploy/{worker/model_runner => cache_manager/transfer_factory}/__init__.py (74%) create mode 100644 fastdeploy/cache_manager/transfer_factory/ipc_cache_transfer.py create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/CMakeLists.txt create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/README.md create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/README_CN.md create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/log.h create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/util.h create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/log.cpp create mode 100644 fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp create mode 100644 fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py create mode 100644 fastdeploy/demo/offline_demo.py create mode 100644 fastdeploy/demo/offline_disaggregated_demo.py create mode 100644 fastdeploy/demo/offline_prefix_caching_demo.py create mode 100644 fastdeploy/demo/openai_demo.py create mode 100644 fastdeploy/demo/openai_vl_demo.py create mode 100644 fastdeploy/download_model.py create mode 100644 fastdeploy/engine/expert_service.py create mode 100644 fastdeploy/entrypoints/openai/test_openai.py create mode 100644 fastdeploy/envs.py delete mode 100644 fastdeploy/inference_args.py create 
mode 100644 fastdeploy/input/ernie_processor.py create mode 100644 fastdeploy/input/ernie_vl_processor.py delete mode 100644 fastdeploy/inter_communicator.py create mode 100644 fastdeploy/inter_communicator/__init__.py create mode 100644 fastdeploy/inter_communicator/engine_cache_queue.py create mode 100644 fastdeploy/inter_communicator/engine_worker_queue.py create mode 100644 fastdeploy/inter_communicator/ipc_signal.py create mode 100644 fastdeploy/inter_communicator/zmq_client.py create mode 100644 fastdeploy/model_executor/guided_decoding/__init__.py create mode 100644 fastdeploy/model_executor/guided_decoding/base_guided_decoding.py create mode 100644 fastdeploy/model_executor/guided_decoding/ernie_tokenizer.py create mode 100644 fastdeploy/model_executor/guided_decoding/xgrammar_backend.py delete mode 100644 fastdeploy/model_executor/layers/attention/base.py rename custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit_kernels/__init__.py => fastdeploy/model_executor/layers/attention/ops/init_signal_layerwise.py (52%) create mode 100644 fastdeploy/model_executor/layers/attention/ops/open_shm_and_get_meta_signal.py create mode 100644 fastdeploy/model_executor/layers/attention/xpu_attn_backend.py rename fastdeploy/model_executor/{eplb => layers/backends/xpu/quantization}/__init__.py (83%) delete mode 100644 fastdeploy/model_executor/layers/moe/cutlass_fused_moe.py create mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py create mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py create mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py create mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py delete mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_method_base.py create mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py create mode 100644 fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py delete mode 100644 fastdeploy/model_executor/layers/moe/mm.py delete mode 100644 fastdeploy/model_executor/layers/moe/tp.py create mode 100644 fastdeploy/model_executor/layers/moe/triton_moe_kernels.py rename fastdeploy/model_executor/layers/quantization/{block_wise.py => block_wise_fp8.py} (52%) create mode 100644 fastdeploy/model_executor/layers/quantization/mix_quant.py rename {custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/jit => fastdeploy/model_executor/layers/quantization/ops}/__init__.py (65%) create mode 100644 fastdeploy/model_executor/layers/quantization/ops/cutlass_scaled_mm.py create mode 100644 fastdeploy/model_executor/layers/quantization/ops/scaled_fp8_quant.py create mode 100644 fastdeploy/model_executor/layers/quantization/tensor_wise_fp8.py create mode 100644 fastdeploy/model_executor/layers/quantization/w4a8.py create mode 100644 fastdeploy/model_executor/layers/quantization/wint2.py create mode 100644 fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_moe.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_mtp.py rename fastdeploy/{worker/V1 => model_executor/models/ernie4_5_vl}/__init__.py (100%) create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/configuration.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/__init__.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/activation.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/configuration.py create mode 100644 
fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/dist_utils.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py create mode 100644 fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py delete mode 100644 fastdeploy/model_executor/models/export_model.py create mode 100644 fastdeploy/model_executor/models/qwen3.py create mode 100644 fastdeploy/model_executor/models/qwen3moe.py delete mode 100644 fastdeploy/model_executor/models/tokenizer.py create mode 100644 fastdeploy/model_executor/ops/triton_ops/__init__.py create mode 100644 fastdeploy/model_executor/ops/triton_ops/triton_utils.py create mode 100644 fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py create mode 100644 fastdeploy/reasoning/__init__.py create mode 100644 fastdeploy/reasoning/abs_reasoning_parsers.py create mode 100644 fastdeploy/reasoning/ernie_vl_reasoning_parsers.py create mode 100644 fastdeploy/reasoning/qwen3_reasoning_parsers.py create mode 100644 fastdeploy/scheduler/splitwise_scheduler.py create mode 100644 fastdeploy/scheduler/utils.py create mode 100644 fastdeploy/spec_decode/__init__.py create mode 100644 fastdeploy/spec_decode/base.py create mode 100644 fastdeploy/spec_decode/mtp.py create mode 100644 fastdeploy/spec_decode/ngram.py rename fastdeploy/{worker/V1/xpu_worker.py => splitwise/__init__.py} (99%) create mode 100644 fastdeploy/splitwise/splitwise_connector.py create mode 100644 fastdeploy/start_splitwise.sh create mode 100644 fastdeploy/stop.sh create mode 100644 fastdeploy/test.yaml delete mode 100644 fastdeploy/worker/V1/gpu_model_runner.py delete mode 100644 fastdeploy/worker/V1/worker_process.py rename fastdeploy/{model_executor/eplb => worker}/eplb.py (67%) rename fastdeploy/{model_executor/eplb => worker}/experts_manager.py (66%) rename fastdeploy/worker/{model_runner => }/forward_meta.py (66%) create mode 100644 fastdeploy/worker/gpu_model_runner.py rename fastdeploy/worker/{V1 => }/gpu_worker.py (57%) delete mode 100644 fastdeploy/worker/model_runner/model_runner_inference.py delete mode 100644 fastdeploy/worker/model_runner/model_runner_minimal_os.py delete mode 100644 fastdeploy/worker/model_runner/model_runner_paddlenlp.py rename fastdeploy/worker/{V1 => }/model_runner_base.py (60%) create mode 100644 fastdeploy/worker/vl_gpu_model_runner.py rename fastdeploy/worker/{model_runner/model_runner_base.py => vl_model_runner_base.py} (56%) rename fastdeploy/worker/{worker.py => vl_worker_process.py} (62%) rename fastdeploy/worker/{V1 => }/worker_base.py (75%) create mode 100644 fastdeploy/worker/worker_process.py create mode 100644 fastdeploy/worker/xpu_model_runner.py create mode 100644 fastdeploy/worker/xpu_worker.py create mode 100644 mkdocs.yml delete mode 100644 scripts/codestyle/clang-tidy.py delete mode 100644 scripts/codestyle/clang_format.sh delete mode 100644 scripts/codestyle/copyright.py delete mode 100755 scripts/codestyle/pre_commit.sh create mode 100644 scripts/convert_ep_to_safetensor.py create mode 100644 scripts/extract_mtp_weight_from_safetensor.py delete mode 100644 scripts/prefill_fake_server.sh delete mode 100644 scripts/run_ci.sh delete mode 100644 scripts/run_offline_quantization.sh create mode 100644 scripts/run_prediction_ep_decoder_multi_node.sh create mode 100644 scripts/run_prediction_ep_decoder_multi_node_perf.sh create mode 100644 scripts/run_prediction_ep_decoder_single_node_perf.sh rename scripts/{codestyle/cpplint_pre_commit.sh => 
run_prediction_ep_prefill_perf.sh} (79%) create mode 100644 scripts/vit_model_split.py create mode 100644 scripts/vit_model_split.sh delete mode 100644 test/ci_use/test_qwen2_offline.py delete mode 100644 test/ci_use/test_qwen2_serving.py create mode 100644 test/layers/test_append_attention.py create mode 100644 test/operators/test_cutlass_scaled_mm.py create mode 100644 test/operators/test_rejection_top_p_sampling.py delete mode 100644 tools/dockerfile/Dockerfile.ci diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..3bb927623 --- /dev/null +++ b/.clang-format @@ -0,0 +1,29 @@ +# This file is used by clang-format to autoformat paddle source code +# +# The clang-format is part of llvm toolchain. +# It need to install llvm and clang to format source code style. +# +# The basic usage is, +# clang-format -i -style=file PATH/TO/SOURCE/CODE +# +# The -style=file implicit use ".clang-format" file located in one of +# parent directory. +# The -i means inplace change. +# +# The document of clang-format is +# http://clang.llvm.org/docs/ClangFormat.html +# http://clang.llvm.org/docs/ClangFormatStyleOptions.html +--- +Language: Cpp +BasedOnStyle: Google +IndentWidth: 4 +TabWidth: 2 +ContinuationIndentWidth: 4 +AccessModifierOffset: -1 # The private/protected/public has no indent in class +Standard: Cpp11 +AllowAllParametersOfDeclarationOnNextLine: true +BinPackParameters: false +BinPackArguments: false +IncludeBlocks: Preserve +IncludeIsMainSourceRegex: (\.cu)$ +... diff --git a/.gitignore b/.gitignore index 35c771cf5..f94e8f7cc 100644 --- a/.gitignore +++ b/.gitignore @@ -121,7 +121,7 @@ dmypy.json FETCH_HEAD #log -log/ +log*/ checkpoints/ checkpoints_origin/ @@ -158,3 +158,7 @@ custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute # buff custom_ops/tmp* + +build + +.ccls-cache diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4b08b23db..faa05efbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: rev: v0.11.7 hooks: - id: ruff - args: [--output-format, github, --fix] + args: [--output-format, github, --fix, --line-length=120] # # 拼写检查 # - repo: https://github.com/codespell-project/codespell # rev: v2.4.1 @@ -29,14 +29,15 @@ repos: rev: 6.0.1 hooks: - id: isort -# 格式化 -- repo: https://github.com/pre-commit/mirrors-clang-format - rev: v20.1.3 - hooks: - - id: clang-format - # exclude: '.*' - types_or: [c++, cuda] - args: [--style=file, --verbose] +# # 格式化 +# - repo: https://github.com/pre-commit/mirrors-clang-format +# rev: v20.1.3 +# hooks: +# - id: clang-format +# # exclude: '.*' +# types_or: [c++, cuda] +# args: [--style=file, --verbose] + # markdown - repo: https://github.com/jackdewinter/pymarkdown rev: v0.9.29 diff --git a/README.md b/README.md index 86ebda86d..55963d04d 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,8 @@ -# FastDeploy 2.0: 大模型推理部署 -

@@ -11,105 +10,78 @@
-FastDeploy升级2.0版本支持多种大模型推理(当前仅支持Qwen2,更多模型即将更新支持),其推理部署功能涵盖:
+ Installation | Quick Start | Supported Models
-- 一行命令即可快速实现模型的服务化部署,并支持流式生成 -- 利用张量并行技术加速模型推理 -- 支持 PagedAttention 与 continuous batching(动态批处理) -- 兼容 OpenAI 的 HTTP 协议 -- 提供 Weight only int8/int4 无损压缩方案 -- 支持 Prometheus Metrics 指标 +-------------------------------------------------------------------------------- +# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle -> 注意: 如果你还在使用FastDeploy部署小模型(如PaddleClas/PaddleOCR等CV套件模型),请checkout [release/1.1.0分支](https://github.com/PaddlePaddle/FastDeploy/tree/release/1.1.0)。 +## News -## 环境依赖 -- A800/H800/H100 -- Python>=3.10 -- CUDA>=12.3 -- CUDNN>=9.5 -- Linux X64 +**[2025-06] 🔥 Released FastDeploy v2.0:** Supports inference and deployment for ERNIE 4.5. Furthermore, we open-source an industrial-grade PD disaggregation with context caching, dynamic role switching for effective resource utilization to further enhance inference performance for MoE models. -## 安装 +## About -### Docker安装(推荐) -``` -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy:2.0.0.0-alpha -``` +**FastDeploy** is an inference and deployment toolkit for large language models and visual language models based on PaddlePaddle. It delivers **production-ready, out-of-the-box deployment solutions** with core acceleration technologies: -### 源码安装 -#### 安装PaddlePaddle -> 注意安装nightly build版本,代码版本需新于2025.05.30,详见[PaddlePaddle安装](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html),指定安装CUDA 12.6 develop(Nightly build)版本。 -``` -python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ -``` +- 🚀 **Load-Balanced PD Disaggregation**: Industrial-grade solution featuring context caching and dynamic instance role switching. Optimizes resource utilization while balancing SLO compliance and throughput. +- 🔄 **Unified KV Cache Transmission**: Lightweight high-performance transport library with intelligent NVLink/RDMA selection. +- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility. +- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more. +- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill. +- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc. -#### 编译安装FastDeploy +## Requirements -``` -# 编译 -cd FastDeploy -bash build.sh -# 安装 -pip install dist/fastdeploy-2.0.0a0-py3-none-any.whl -``` +- OS: Linux +- Python: 3.10 ~ 3.12 -## 快速使用 +## Installation -在安装后,执行如下命令快速部署Qwen2模型, 更多参数的配置与含义参考[参数说明](docs/serving.md). +FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. 
For detailed installation instructions: -``` shell -# 下载与解压Qwen模型 -wget https://fastdeploy.bj.bcebos.com/llm/models/Qwen2-7B-Instruct.tar.gz && tar xvf Qwen2-7B-Instruct.tar.gz -# 指定单卡部署 -python -m fastdeploy.entrypoints.openai.api_server --model ./Qwen2-7B-Instruct --port 8188 --tensor-parallel-size 1 -``` +- [NVIDIA GPU](./docs/installation/nvidia_cuda.md) +- [Kunlunxin XPU](./docs/en/get_started/installation/kunlunxin_xpu.md) +- [Iluvatar GPU](./docs/en/get_started/installation/iluvatar_gpu.md) +- [Enflame GCU](./docs/en/get_started/installation/Enflame_gcu.md) -使用如下命令请求模型服务 -``` shell -curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \ --H "Content-Type: application/json" \ --d '{ - "messages": [ - {"role": "user", "content": "你好,你的名字是什么?"} - ] -}' -``` -响应结果如下所示 -``` json -{ - "id": "chatcmpl-db662f47-7c8c-4945-9a7a-db563b2ddd8d", - "object": "chat.completion", - "created": 1749451045, - "model": "default", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "你好!我叫通义千问。", - "reasoning_content": null - }, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 25, - "total_tokens": 35, - "completion_tokens": 10, - "prompt_tokens_details": null - } -} -``` -FastDeploy提供与OpenAI完全兼容的服务API(字段`model`与`api_key`目前不支持,设定会被忽略),用户也可基于openai python api请求服务。 +**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates! -## 部署文档 -- [本地部署](docs/offline_inference.md) -- [服务部署](docs/serving.md) -- [服务metrics](docs/metrics.md) +## Get Started -# 代码说明 -- [代码目录说明](docs/code_guide.md) -- FastDeploy的使用中存在任何建议和问题,欢迎通过issue反馈。 +Learn how to use FastDeploy through our documentation: +- [10-Minutes Quick Deployment](./docs/get_started/quick_start.md) +- [ERNIE-4.5 Large Language Model Deployment](./docs/get_started/ernie-4.5.md) +- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md) +- [Offline Inference Development](./docs/offline_inference.md) +- [Online Service Deployment](./docs/serving/README.md) +- [Full Supported Models List](./docs/supported_models.md) -# 开源说明 -FastDeploy遵循[Apache-2.0开源协议](./LICENSE)。 在本项目的开发中,为了对齐[vLLM](https://github.com/vllm-project/vllm)使用接口,参考和直接使用了部分vLLM代码,在此表示感谢。 +## Supported Models + +| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length | +|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- | +|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅(WINT4/W4A8C8/Expert Parallelism)| ✅ | ✅|✅(WINT4)| WIP |128K | +|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅(WINT4/Expert Parallelism)| ✅ | ✅|✅(WINT4)| ❌ | 128K | +|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K | +|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K | +|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K | +|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K | +|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K | + +## Advanced Usage + +- [Quantization](./docs/quantization/README.md) +- [PD Disaggregation Deployment](./docs/features/pd_disaggregation.md) +- [Speculative Decoding](./docs/features/speculative_decoding.md) +- [Prefix Caching](./docs/features/prefix_caching.md) +- [Chunked Prefill](./docs/features/chunked_prefill.md) + +## Acknowledgement + +FastDeploy is licensed under the [Apache-2.0 
open-source license](./LICENSE). During development, portions of [vLLM](https://github.com/vllm-project/vllm) code were referenced and incorporated to maintain interface compatibility, for which we express our gratitude. diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..d7a7e5007 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,106 @@ +### FastDeploy服务化性能压测工具 + +#### 数据集: + +wget下载到本地用于性能测试 + + + + + + + + + + + + + + +
| Dataset | Data Path |
| --- | --- |
| Open-source dataset (2k entries) | https://fastdeploy.bj.bcebos.com/eb_query/filtered_sharedgpt_2000_input_1136_output_200_fd.json |
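
Before launching a benchmark it can help to sanity-check the downloaded file. The snippet below is a minimal, illustrative sketch (not part of the toolkit); it assumes the JSON-lines layout and the `messages` / `max_tokens` fields that `EBChatDataset` in `benchmark_dataset.py` reads, and simply prints the request count and the first prompt:

```python
# Minimal sketch: peek at the FD-format benchmark dataset.
# Assumes the JSON-lines layout read by EBChatDataset (benchmark_dataset.py).
import json

path = "filtered_sharedgpt_2000_input_1136_output_200_fd.json"  # downloaded via wget above
with open(path, encoding="utf-8") as f:
    entries = [json.loads(line) for line in f if line.strip()]

print(f"{len(entries)} requests")
first = entries[0]
# Each entry carries an OpenAI-style message list; the last message is the user prompt.
print(first["messages"][-1].get("content", "")[:200])
print("max_tokens:", first.get("max_tokens", 12288))
```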
+#### 使用方式: + +``` +# 安装依赖 +python -m pip install -r requirements.txt +``` + +##### 参数说明 + +```bash +--backend openai-chat:压测使用的后端接口,指定为"openai-chat"使用chat/completion接口 +--model EB45T:模型名,任意取名,影响最后保存的结果文件名 EB45T \ +--endpoint /v1/chat/completions:endpoint,用于组url +--host 0.0.0.0:服务ip地址,用于组url +--port 9812:服务HTTP端口,用于组url +--dataset-name EBChat:指定数据集类,指定为"EBChat"可读取转存的FD格式数据集 +--dataset-path ./eb45t_spv4_dataserver_1w_waigua_fd:压测数据集路径 +--hyperparameter-path EB45T.yaml:(可选)超参文件,请求时会更新进payload中,默认不带任何超参 +--percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len:性能结果中展示的指标集合 +--metric-percentiles 80,95,99,99.9,99.95,99.99:性能结果中展示的性能指标分位值 +--num-prompts 1:总计发送多少条请求 +--max-concurrency 1:压测并发数 +--save-result:开启结果保存,结果文件会存入json +``` + +##### /v1/chat/completions接口压测单条数据调试 + +``` +python benchmark_serving.py \ + --backend openai-chat \ + --model EB45T \ + --endpoint /v1/chat/completions \ + --host 0.0.0.0 \ + --port 9812 \ + --dataset-name EBChat \ + --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \ + --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \ + --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \ + --metric-percentiles 80,95,99,99.9,99.95,99.99 \ + --num-prompts 1 \ + --max-concurrency 1 \ + --save-result +``` + +##### /v1/chat/completions接口完整100并发 2000条压测 + +``` +# 保存infer_log.txt +python benchmark_serving.py \ + --backend openai-chat \ + --model EB45T \ + --endpoint /v1/chat/completions \ + --host 0.0.0.0 \ + --port 9812 \ + --dataset-name EBChat \ + --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \ + --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \ + --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \ + --metric-percentiles 80,95,99,99.9,99.95,99.99 \ + --num-prompts 2000 \ + --max-concurrency 100 \ + --save-result > infer_log.txt 2>&1 & +``` + +##### /v1/completions接口压测 + +修改endpoint为/v1/completions,backend为openai,会对/v1/completions接口进行压测 + +``` +# 保存infer_log.txt +python benchmark_serving.py \ + --backend openai \ + --model EB45T \ + --endpoint /v1/completions \ + --host 0.0.0.0 \ + --port 9812 \ + --dataset-name EBChat \ + --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \ + --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \ + --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \ + --metric-percentiles 80,95,99,99.9,99.95,99.99 \ + --num-prompts 2000 \ + --max-concurrency 100 \ + --save-result > infer_log.txt 2>&1 & +``` + diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py new file mode 100644 index 000000000..84b11d7a9 --- /dev/null +++ b/benchmarks/backend_request_func.py @@ -0,0 +1,700 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py + + +import io +import json +import os +import sys +import time +import traceback +from dataclasses import dataclass, field +from typing import Optional + +import aiohttp +from tqdm.asyncio import tqdm + + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + """Input for requesting LLMs via API""" + prompt: str + history_QA: Optional[dict] + hyper_parameters: dict + api_url: str + prompt_len: int + output_len: int + model: str + model_name: Optional[str] = None + logprobs: Optional[int] = None + extra_body: Optional[dict] = None + multi_modal_content: Optional[dict] = None + ignore_eos: bool = False + language: Optional[str] = None + + +@dataclass +class RequestFuncOutput: + """Output for requesting LLMs via API""" + generated_text: str = "" + reasoning_content: str = "" + success: bool = False + latency: float = 0.0 + output_tokens: int = 0 + ttft: float = 0.0 # Time to first token + arrival_time: list = field(default_factory=list) # arrival_time + itl: list = field(default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + prompt_tokens: int = 0 # 推理侧返回输入token数 + error: str = "" + + +async def async_request_eb_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using EB OpenAI""" + api_url = request_func_input.api_url + assert api_url.endswith( + ("completions", "profile") + ), "OpenAI Chat Completions API URL must end with 'completions'." + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + payload = { + "model": "default", + "messages": request_func_input.history_QA, + "stream": True, + "stream_options": { + "include_usage": True, + "continuous_usage_stats": True + }, + } + # 超参由yaml传入 + payload.update(request_func_input.hyper_parameters) + + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = 0 + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + # print("####chunk:", chunk, type(chunk)) + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + reason_content = choices[0]["delta"].get("reasoning_content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + # cached_tokens + output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"] + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + output.generated_text += content or "" + output.reasoning_content += reason_content or "" + output.arrival_time.append(choices[0].get("arrival_time")) + elif 
usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + output.prompt_tokens = usage.get( + "prompt_tokens") + + most_recent_timestamp = timestamp + + # output.generated_text = generated_text + if output.generated_text.strip() == "": + output.success = False + output.error = "No generated text found!" + else: + output.success = True + output.latency = most_recent_timestamp - st + else: + error_text = await response.text() + print("####error response:", error_text, "####payload:", payload) + output.error = error_text or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + # 保存失败请求结果 + if not output.success: + with open("error_output.txt", "a") as f: + f.write(str(output) + "\n") + if pbar: + pbar.update(1) + return output + + +async def async_request_eb_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using EB OpenAI""" + api_url = request_func_input.api_url + assert api_url.endswith( + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": "default", + "prompt": request_func_input.prompt, + "stream": True, + "stream_options": { + "include_usage": True, + "continuous_usage_stats": True + }, + } + # 超参由yaml传入 + payload.update(request_func_input.hyper_parameters) + + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + # print("####chunk:", chunk, chunk.usage) + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + output.arrival_time.append(choices[0].get("arrival_time")) + generated_text += text or "" + elif usage := data.get("usage"): + output.prompt_tokens = usage.get( + "prompt_tokens") + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." 
+ "This response will be marked as failed!") + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_tgi( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using the TGI API""" + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + params = { + "max_new_tokens": request_func_input.output_len, + "do_sample": True, + "temperature": 0.01, # TGI does not accept 0.0 temperature. + "top_p": 0.99, # TGI does not accept 1.0 top_p. + "truncate": request_func_input.prompt_len, + "ignore_eos_token": request_func_input.ignore_eos, + } + payload = { + "inputs": request_func_input.prompt, + "parameters": params, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + if request_func_input.ignore_eos: + output.output_tokens = request_func_input.output_len + else: + output.output_tokens = None + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk_bytes = chunk_bytes.decode("utf-8") + + # NOTE: Sometimes TGI returns a ping response without + # any data, we should skip it. + if chunk_bytes.startswith(":"): + continue + chunk = chunk_bytes.removeprefix("data:") + + data = json.loads(chunk) + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + output.arrival_time.append(data["arrival_time"]) + + output.latency = most_recent_timestamp - st + output.success = True + output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_trt_llm( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using TRT's llm_server""" + api_url = request_func_input.api_url + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "accumulate_tokens": True, + "text_input": request_func_input.prompt, + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": request_func_input.output_len, + "stream": True, + } + if request_func_input.ignore_eos: + payload["min_length"] = request_func_input.output_len + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + 
chunk = chunk_bytes.decode("utf-8").removeprefix( + "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + + output.latency = most_recent_timestamp - st + output.success = True + + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_deepspeed_mii( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using Deepspeed MII""" + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + + payload = { + "prompt": request_func_input.prompt, + "max_tokens": request_func_input.output_len, + "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. + "top_p": 1.0, + } + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024, + # will use 0 as placeholder. + # See https://github.com/microsoft/DeepSpeed-MII/pull/311 + output.ttft = 0 + + st = time.perf_counter() + try: + async with session.post(url=request_func_input.api_url, + json=payload) as response: + if response.status == 200: + parsed_resp = await response.json() + output.latency = time.perf_counter() - st + if "choices" in parsed_resp: + output.generated_text = parsed_resp["choices"][0][ + "text"] + elif "text" in parsed_resp: + output.generated_text = parsed_resp["text"][0] + else: + output.error = ("Unexpected response format: " + "neither 'choices' nor 'text' found") + output.success = False + output.success = True + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using OpenAI""" + api_url = request_func_input.api_url + assert api_url.endswith( + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "prompt": request_func_input.prompt, + # "temperature": 0.0, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + #"stream_options": { + # "include_usage": True, + #}, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + # print("####chunk:", chunk, type(chunk)) + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!") + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Request an LLM using OpenAI""" + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + api_url = request_func_input.api_url + assert api_url.endswith( + ("transcriptions", "translations" + )), "OpenAI Chat Completions API URL must end with 'transcriptions' " + "or `translations`." 
+ + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", + # Flattened due to multipart/form-data + "stream_include_usage": True, + "stream_continuous_usage_stats": True + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + with to_bytes(*request_func_input.multi_modal_content['audio']) as f: + form = aiohttp.FormData() + form.add_field('file', f, content_type='audio/wav') + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, + data=form, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +ASYNC_REQUEST_FUNCS = { + "tgi": async_request_tgi, + "vllm": async_request_openai_completions, + "lmdeploy": async_request_openai_completions, + "deepspeed-mii": async_request_deepspeed_mii, + "openai": async_request_eb_openai_completions, + "openai-chat": async_request_eb_openai_chat_completions, + "openai-audio": async_request_openai_audio, + "tensorrt-llm": async_request_trt_llm, + "scalellm": async_request_openai_completions, + "sglang": async_request_openai_completions, +} + +OPENAI_COMPATIBLE_BACKENDS = [ + k for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, + async_request_eb_openai_chat_completions) +] + diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py new file mode 100644 index 000000000..2d8bcca34 --- /dev/null +++ b/benchmarks/benchmark_dataset.py @@ -0,0 +1,309 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_dataset.py + + +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from io import BytesIO +from typing import Any, Callable, Optional, Union +from PIL import Image + + +logger = logging.getLogger(__name__) + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. + """ + + prompt: Union[str, Any] + history_QA: Union[str, Any] + json_data: Optional[dict] + prompt_len: int + expected_output_len: int + + +class BenchmarkDataset(ABC): + """BenchmarkDataset""" + DEFAULT_SEED = 0 + IS_MULTIMODAL = False + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + hyperparameter_path: Optional[str] = None, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + self.hyperparameter_path = hyperparameter_path + self.hyperparameters = {} + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. + """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + @abstractmethod + def sample(self, num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. + + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. num_requests (int): The target number of requests. 
+ """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. + """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = image.convert("RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +class EBDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. 
+ """ + + temperature: float + repetition_penalty: float + frequency_penalty: float + presence_penalty: float + top_p: float + prompt_len: int + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = [json.loads(i.strip()) for i in f.readlines()] + + def sample( + self, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt = entry["text"] + self.temperature = float(entry["temperature"]) + self.repetition_penalty = float(entry["penalty_score"]) + self.frequency_penalty = float(entry["frequency_score"]) + self.presence_penalty = float(entry["presence_score"]) + self.top_p = float(entry["topp"]) + self.prompt_len = int(entry["input_token_num"]) + new_output_len = int(entry["max_dec_len"]) + + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=self.prompt_len, + history_QA=[], + expected_output_len=new_output_len, + )) + + self.maybe_oversample_requests(samples, num_requests) + return samples + + +class EBChatDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + prompt_len: int + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = [json.loads(i.strip()) for i in f.readlines()] + + def sample( + self, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + json_data = entry + prompt = entry["messages"][-1].get("content", "") + history_QA = entry.get("messages", []) + new_output_len = int(entry.get("max_tokens", 12288)) + + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + json_data=json_data, + prompt=prompt, + prompt_len=0, + history_QA=history_QA, + expected_output_len=new_output_len, + )) + + self.maybe_oversample_requests(samples, num_requests) + return samples + diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py new file mode 100644 index 000000000..924f96ad4 --- /dev/null +++ b/benchmarks/benchmark_serving.py @@ -0,0 +1,1141 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py + + +import argparse +import asyncio +import gc +import json +import os +import random +import time +import warnings +import yaml +from collections.abc import AsyncGenerator, Iterable +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +import numpy as np +from backend_request_func import (ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, + RequestFuncOutput) +from tqdm.asyncio import tqdm + +from argparse import ArgumentParser as FlexibleArgumentParser + +from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset) +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + """Class containing all metrics that are used in this script""" + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_s_decode: float + median_s_decode: float + std_s_decode: float + percentiles_s_decode: list[tuple[float, float]] + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_s_ttft_ms: float + median_s_ttft_ms: float + std_s_ttft_ms: float + percentiles_s_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + mean_s_itl_ms: float + median_s_itl_ms: float + std_s_itl_ms: float + percentiles_s_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + mean_s_e2el_ms: float + median_s_e2el_ms: float + std_s_e2el_ms: float + percentiles_s_e2el_ms: list[tuple[float, float]] + mean_input_len: float + median_input_len: float + std_input_len: float + percentiles_input_len: list[tuple[float, float]] + mean_s_input_len: float + median_s_input_len: float + std_s_input_len: float + percentiles_s_input_len: list[tuple[float, float]] + mean_output_len: float + median_output_len: float + std_output_len: float + percentiles_output_len: list[tuple[float, float]] + + +async def get_request( + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[SampleRequest, None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a SampleRequest. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. 
+ Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ + input_requests: Iterable[SampleRequest] = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: list[SampleRequest], + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[BenchmarkMetrics, list[int]]: + """Calculates various performance metrics based on the inputs and outputs.""" + input_lens: list[int] = [] + infer_input_lens: list[int] = [] # 推理侧输入token数 + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + s_itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + s_ttfts: list[float] = [] + e2els: list[float] = [] + s_e2els: list[float] = [] + s_decodes: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_tokens + + if not output_len: + print("no output_len") + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + + actual_output_lens.append(output_len) + input_lens.append(outputs[i].prompt_len) + infer_input_lens.append(outputs[i].prompt_tokens) + total_input += outputs[i].prompt_tokens + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + # 推理侧ITL + s_a = outputs[i].arrival_time[1:] + for j in range(len(s_a) - 2): + s_itls.append(s_a[j + 1] - s_a[j]) + ttfts.append(outputs[i].ttft) + # 推理侧TTFT + s_ttfts.append(outputs[i].arrival_time[1]) + e2els.append(outputs[i].latency) + # 推理侧整句时延 + s_e2els.append(outputs[i].arrival_time[-1]) + # 解码速度去掉首token + if len(outputs[i].arrival_time) > 2: + s_decodes.append((outputs[i].output_tokens - 1) / + (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])) + completed += 1 + else: + actual_output_lens.append(0) + input_lens.append(0) + infer_input_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + 
slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_s_decode=np.mean(s_decodes or 0) * + 1, # ttfts is empty if streaming is not supported by backend + std_s_decode=np.std(s_decodes or 0) * 1, + median_s_decode=np.median(s_decodes or 0) * 1, + percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) + for p in selected_percentiles], + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_s_ttft_ms=np.mean(s_ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + std_s_ttft_ms=np.std(s_ttfts or 0) * 1000, + median_s_ttft_ms=np.median(s_ttfts or 0) * 1000, + percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_s_itl_ms=np.mean(s_itls or 0) * 1000, + std_s_itl_ms=np.std(s_itls or 0) * 1000, + median_s_itl_ms=np.median(s_itls or 0) * 1000, + percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000, + std_s_e2el_ms=np.std(s_e2els or 0) * 1000, + median_s_e2el_ms=np.median(s_e2els or 0) * 1000, + percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) + for p in selected_percentiles], + mean_input_len=np.mean(input_lens or 0) * 1, + std_input_len=np.std(input_lens or 0) * 1, + median_input_len=np.median(input_lens or 0) * 1, + percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) + for p in selected_percentiles], + mean_s_input_len=np.mean(infer_input_lens or 0) * 1, + std_s_input_len=np.std(infer_input_lens or 0) * 1, + median_s_input_len=np.median(infer_input_lens or 0) * 1, + percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) + for p in selected_percentiles], + mean_output_len=np.mean(actual_output_lens or 0) * 1, + std_output_len=np.std(actual_output_lens or 0) * 1, 
+ median_output_len=np.median(actual_output_lens or 0) * 1, + percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) + for p in selected_percentiles], + ) + + return metrics, actual_output_lens + + +async def benchmark( + backend: str, + api_url: str, + base_url: str, + model_id: str, + model_name: str, + input_requests: list[SampleRequest], + hyper_parameters: dict, + logprobs: Optional[int], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[Iterable[str]], + extra_body: Optional[dict], +): + """Benchmarks an API endpoint using a given set of sample inputs and returns""" + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + print("Starting initial single prompt test run...") + test_prompt, test_output_len = \ + input_requests[0].prompt, \ + input_requests[0].expected_output_len + test_history_QA = input_requests[0].history_QA + + test_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + prompt_len=0, + history_QA=test_history_QA, + hyper_parameters=hyper_parameters, + api_url=api_url, + output_len=test_output_len, + logprobs=logprobs, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) + + print("test_input:", test_input) + + test_output = await request_func(request_func_input=test_input) + + print("test_output:", test_output) + + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") + + if lora_modules: + # For each input request, choose a LoRA module at random. + lora_modules = iter( + [random.choice(lora_modules) \ + for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + output_len=test_output_len, + logprobs=logprobs, + ignore_eos=ignore_eos, + extra_body=extra_body) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. 
+ # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, burstiness): + prompt, output_len = request.prompt, request.expected_output_len + history_QA = request.history_QA + + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, + prompt=prompt, + prompt_len=0, + history_QA=history_QA, + hyper_parameters=hyper_parameters, + api_url=api_url, + output_len=output_len, + logprobs=logprobs, + ignore_eos=ignore_eos, + extra_body=extra_body) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + output_len=test_output_len, + logprobs=logprobs, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + # tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.3f}".format("Request throughput (req/s):", + metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput:": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "infer_input_lens": [output.prompt_tokens for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for 
output in outputs], + "input_texts": [input.prompt for input in input_requests], + "generated_texts": [output.generated_text for output in outputs], + "reasoning_contents": [output.reasoning_content for output in outputs], + "errors": [output.error for output in outputs], + } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + def process_one_length( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name}:", + getattr(metrics, f"mean_{metric_attribute_name}"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name}:", + getattr(metrics, f"median_{metric_attribute_name}"))) + result[f"mean_{metric_attribute_name}"] = getattr( + metrics, f"mean_{metric_attribute_name}") + result[f"median_{metric_attribute_name}"] = getattr( + metrics, f"median_{metric_attribute_name}") + result[f"std_{metric_attribute_name}"] = getattr( + metrics, f"std_{metric_attribute_name}") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", + value)) + result[f"p{p_word}_{metric_attribute_name}"] = value + + process_one_length("s_decode", "Decode", "解码速度(tok/s)") + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 
1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency") + process_one_length("input_len", "Cached Tokens", "Cached Tokens") + process_one_length("s_input_len", "Input Length", "Infer Input Length") + process_one_length("output_len", "Output Length", "Output Length") + + print("=" * 50) + + return result + + +def check_goodput_args(args): + """Check whether the given argument has valid goodput configuration or not""" + # Check and parse goodput arguments + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. " + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def parse_goodput(slo_pairs): + """Parse the string into a dictionary with keys being names of SLOS and values being their corresponding values""" + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any], + file_name: str) -> None: + """Save the benchmarking results to PyTorch Benchmark Format JSON file""" + metrics = [ + "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", + "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", + "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + ] + # These raw data might be useful, but they are rather big. 
They can be added + # later if needed + ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={k: [results[k]] + for k in metrics}, + extra_info={ + k: results[k] + for k in results if k not in metrics and k not in ignored_metrics + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def main(args: argparse.Namespace): + """Main entry point""" + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + backend = args.backend + model_id = args.model + model_name = args.served_model_name + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") + + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "EB": + lambda: EBDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "EBChat": + lambda: EBChatDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + } + + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + + goodput_config_dict = check_goodput_args(args) + + # Collect the sampling parameters. + sampling_params = { + k: v + for k, v in { + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "temperature": args.temperature + }.items() if v is not None + } + + # Sampling parameters are only supported by openai-compatible backend. + if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: + raise ValueError( + "Sampling parameters are only supported by openai-compatible " + "backends.") + + if "temperature" not in sampling_params: + sampling_params["temperature"] = 0.0 # Default to greedy decoding. + + # Avoid GC processing "static" data - reduce pause times. 
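+    # gc.collect() clears garbage created during setup, and gc.freeze() moves the
+    # surviving tracked objects into the permanent generation so later collections
+    # skip them (available since Python 3.7).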
+ gc.collect() + gc.freeze() + + # 超参由yaml传入 + if args.hyperparameter_path: + with open(args.hyperparameter_path, "r") as f: + hyper_parameters = yaml.safe_load(f) + else: + hyper_parameters = {} + + benchmark_result = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + input_requests=input_requests, + hyper_parameters=hyper_parameters, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + )) + + # Save config and results to json + if args.save_result: + result_json: dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["backend"] = backend + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["num_prompts"] = args.num_prompts + + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." + ) + + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", "output_lens", "ttfts", "itls", + "generated_texts", "errors" + ]: + if field in result_json: + del result_json[field] + + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + # Save to file + base_model_id = model_id.split("/")[-1] + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + if args.result_filename: + file_name = args.result_filename + if args.result_dir: + file_name = os.path.join(args.result_dir, file_name) + with open(file_name, "w", encoding='utf-8') as outfile: + json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + default="vllm", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="sharegpt", + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument("--dataset-path", + 
type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument("--hyperparameter-path", + type=str, + default=None, + help="Path to the hyperparameter. ") + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--save-result", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--save-detailed", + action="store_true", + help="When saving the results, whether to include per request " + "information such as response, error, ttfs, tpots, etc.", + ) + parser.add_argument( + "--metadata", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " + "for metadata of this run to be saved in the result JSON file " + "for record keeping purposes.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-separated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-separated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". " + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help="Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help="Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help="Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. 
Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help=("Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]."), + ) + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + sampling_group = parser.add_argument_group("sampling parameters") + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Top-p sampling parameter. Only has effect on openai-compatible " + "backends.") + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling parameter. Only has effect on openai-compatible " + "backends.") + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Min-p sampling parameter. Only has effect on openai-compatible " + "backends.") + sampling_group.add_argument( + "--temperature", + type=float, + default=None, + help="Temperature sampling parameter. Only has effect on " + "openai-compatible backends. If not specified, default to greedy " + "decoding (i.e. temperature==0.0).") + + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.") + + args = parser.parse_args() + + main(args) + diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py new file mode 100644 index 000000000..6c149bf5f --- /dev/null +++ b/benchmarks/benchmark_utils.py @@ -0,0 +1,90 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_utils.py + + +import argparse +import json +import math +import os +from typing import Any + + +def convert_to_pytorch_benchmark_format(args: argparse.Namespace, + metrics: dict[str, list], + extra_info: dict[str, Any]) -> list: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + + tp = record["benchmark"]["extra_info"]["args"].get( + "tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + record["benchmark"]["extra_info"]["args"][ + "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + + records.append(record) + + return records + + +class InfEncoder(json.JSONEncoder): + """InfEncoder""" + def clear_inf(self, o: Any): + """clear_inf""" + if isinstance(o, dict): + return {k: self.clear_inf(v) for k, v in o.items()} + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + """iterencode""" + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: list) -> None: + """write_to_json""" + with open(filename, "w") as f: + json.dump(records, f, cls=InfEncoder) + diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000..1ad085b79 --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,5 @@ +aiohttp +tqdm +numpy +Pillow +pyyaml diff --git a/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml new file mode 100644 index 000000000..280f8e336 --- /dev/null +++ b/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml @@ -0,0 +1,8 @@ +enable_chunked_prefill: True +max_model_len: 131072 +max_num_seqs: 16 +kv_cache_ratio: 0.75 +tensor_parallel_size: 8 +max_num_batched_tokens: 4096 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 diff --git a/benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml new file mode 100644 index 000000000..d3aaa9243 --- /dev/null +++ b/benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml @@ -0,0 +1,5 @@ +max_model_len: 131072 +max_num_seqs: 40 +gpu_memory_utilization: 0.9 +tensor_parallel_size: 8 +quantization: wint4 diff --git 
a/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml new file mode 100644 index 000000000..280f8e336 --- /dev/null +++ b/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml @@ -0,0 +1,8 @@ +enable_chunked_prefill: True +max_model_len: 131072 +max_num_seqs: 16 +kv_cache_ratio: 0.75 +tensor_parallel_size: 8 +max_num_batched_tokens: 4096 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 diff --git a/benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml b/benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml new file mode 100644 index 000000000..db8a20b86 --- /dev/null +++ b/benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml @@ -0,0 +1,10 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 1 +enable_chunked_prefill: True +max_num_batched_tokens: 384 +quantization: wint4 +reasoning_parser: ernie-45-vl \ No newline at end of file diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml new file mode 100644 index 000000000..f57706607 --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +max_num_batched_tokens: 32768 diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml new file mode 100644 index 000000000..783a42c6b --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 32 +kv_cache_ratio: 0.5 +tensor_parallel_size: 1 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml new file mode 100644 index 000000000..366b4952e --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +max_num_batched_tokens: 32768 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml new file mode 100644 index 000000000..b5add626e --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +max_num_batched_tokens: 32768 +quantization: wint8 diff --git a/benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml b/benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml new file mode 100644 index 000000000..f57706607 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +max_num_batched_tokens: 32768 diff --git a/benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml b/benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml new file mode 100644 index 000000000..b2f9a7457 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml @@ -0,0 +1,12 @@ +max_model_len: 32768 +max_num_seqs: 256 +tensor_parallel_size: 8 +quantization: block_wise_fp8 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +enable_chunked_prefill: True +max_num_batched_tokens: 1024 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 +enable_prefix_caching: True +swap_space: 200 diff --git a/benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml b/benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml new file mode 100644 index 000000000..47d1bfbcd --- /dev/null +++ 
b/benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml @@ -0,0 +1,11 @@ +max_model_len: 32768 +max_num_seqs: 256 +tensor_parallel_size: 8 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +enable_chunked_prefill: True +max_num_batched_tokens: 1024 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 +enable_prefix_caching: True +swap_space: 200 diff --git a/benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml new file mode 100644 index 000000000..6ac9a2188 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 96 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 4 diff --git a/benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml new file mode 100644 index 000000000..957f59d2a --- /dev/null +++ b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml @@ -0,0 +1,15 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 4 +cache_queue_port: 55663 +enable_chunked_prefill: True +splitwise_role: decode +engine_worker_queue_port: 6678 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7671,7672,7673,7674" +pd_comm_port: "2334" +max_num_batched_tokens: 384 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml new file mode 100644 index 000000000..c1466160d --- /dev/null +++ b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml @@ -0,0 +1,12 @@ +max_model_len: 32768 +max_num_seqs: 16 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.9 +tensor_parallel_size: 4 +splitwise_role: prefill +enable_prefix_caching: True +cache_queue_port: 55664 +engine_worker_queue_port: 6677 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7675,7676,7677,7678" +pd_comm_port: "2333" \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml b/benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml new file mode 100644 index 000000000..af8d49e80 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_prefix_caching: true +enable_chunked_prefill: true diff --git a/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml new file mode 100644 index 000000000..6ac9a2188 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 96 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 4 diff --git a/benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml new file mode 100644 index 000000000..2e00aad6d --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml @@ -0,0 +1,13 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 1 +data_parallel_size: 8 +num_gpu_blocks_override: 1024 +cache_queue_port: 55663 +splitwise_role: decode +engine_worker_queue_port: 6678 +cache_transfer_protocol: "rdma" +rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678" +pd_comm_port: "2334" diff --git a/benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml new file mode 100644 index 000000000..e6d0fa6e0 --- /dev/null +++ 
b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml @@ -0,0 +1,13 @@ +max_model_len: 32768 +max_num_seqs: 16 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.9 +tensor_parallel_size: 1 +data_parallel_size: 8 +splitwise_role: prefill +cache_queue_port: 55664 +engine_worker_queue_port: 6677 +num_gpu_blocks_override: 1024 +cache_transfer_protocol: "rdma" +rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678" +pd_comm_port: "2334" \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml new file mode 100644 index 000000000..c609fba49 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 96 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 4 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml new file mode 100644 index 000000000..e239cea89 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml @@ -0,0 +1,13 @@ +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.7 +tensor_parallel_size: 4 +cache_queue_port: 55663 +enable_chunked_prefill: False +enable_prefix_caching: False +splitwise_role: decode +engine_worker_queue_port: 6678 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7671,7672,7673,7674" +pd_comm_port: "2334" \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml new file mode 100644 index 000000000..6d759c843 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml @@ -0,0 +1,12 @@ +max_model_len: 32768 +max_num_seqs: 16 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.9 +tensor_parallel_size: 4 +splitwise_role: prefill +enable_prefix_caching: False +cache_queue_port: 55664 +engine_worker_queue_port: 6677 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7675,7676,7677,7678" +pd_comm_port: "2333" \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml new file mode 100644 index 000000000..14f025dc0 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 40 +tensor_parallel_size: 4 +quantization: wint4 +gpu_memory_utilization: 0.9 diff --git a/benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml new file mode 100644 index 000000000..b5059f185 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 160 +tensor_parallel_size: 8 +quantization: wint4 +gpu_memory_utilization: 0.9 diff --git a/benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml new file mode 100644 index 000000000..5a5de2aba --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml @@ -0,0 +1,8 @@ +enable_prefix_caching: True +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 4 +swap_space: 200 +cache_queue_port: 55664 diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml new file mode 100644 index 000000000..957f59d2a --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml @@ -0,0 +1,15 @@ +max_model_len: 32768 +max_num_seqs: 256 
+gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 4 +cache_queue_port: 55663 +enable_chunked_prefill: True +splitwise_role: decode +engine_worker_queue_port: 6678 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7671,7672,7673,7674" +pd_comm_port: "2334" +max_num_batched_tokens: 384 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml new file mode 100644 index 000000000..c1466160d --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml @@ -0,0 +1,12 @@ +max_model_len: 32768 +max_num_seqs: 16 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.9 +tensor_parallel_size: 4 +splitwise_role: prefill +enable_prefix_caching: True +cache_queue_port: 55664 +engine_worker_queue_port: 6677 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7675,7676,7677,7678" +pd_comm_port: "2333" \ No newline at end of file diff --git a/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml new file mode 100644 index 000000000..a8a51c086 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 96 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 8 diff --git a/benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml new file mode 100644 index 000000000..f1fde433f --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 80 +tensor_parallel_size: 8 +quantization: wint8 +gpu_memory_utilization: 0.9 diff --git a/benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml new file mode 100644 index 000000000..e597f5bb7 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml @@ -0,0 +1,9 @@ +enable_prefix_caching: True +max_model_len: 32768 +max_num_batched_tokens: 68304 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 8 +swap_space: 100 +cache_queue_port: 55664 diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml new file mode 100644 index 000000000..1a53f9b9a --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml @@ -0,0 +1,9 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 56 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 8 +quantization: wint4 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml new file mode 100644 index 000000000..31d3f5a14 --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml @@ -0,0 +1,11 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 56 +gpu_memory_utilization: 0.8 +kv_cache_ratio: 0.8 +tensor_parallel_size: 8 +quantization: wint4 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +enable_chunked_prefill: True +max_num_batched_tokens: 384 +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml new file mode 100644 index 000000000..9646a4c61 --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml @@ -0,0 +1,9 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 36 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 
+tensor_parallel_size: 4 +quantization: wint4 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml new file mode 100644 index 000000000..3c803e662 --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml @@ -0,0 +1,9 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 36 +gpu_memory_utilization: 0.95 +kv_cache_ratio: 0.8 +tensor_parallel_size: 8 +quantization: wint8 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml new file mode 100644 index 000000000..ff9611f5d --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml @@ -0,0 +1,11 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 36 +gpu_memory_utilization: 0.8 +kv_cache_ratio: 0.8 +tensor_parallel_size: 8 +quantization: wint8 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +enable_chunked_prefill: True +max_num_batched_tokens: 384 +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml new file mode 100644 index 000000000..e01db1566 --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml @@ -0,0 +1,9 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 36 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 4 +quantization: wint8 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml new file mode 100644 index 000000000..14024b565 --- /dev/null +++ b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint8 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml new file mode 100644 index 000000000..14024b565 --- /dev/null +++ b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint8 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 
+max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml new file mode 100644 index 000000000..010dd3bc3 --- /dev/null +++ b/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint4 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml b/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml new file mode 100644 index 000000000..eec95559d --- /dev/null +++ b/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 96 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 4 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml new file mode 100644 index 000000000..c88178259 --- /dev/null +++ b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml @@ -0,0 +1,4 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml new file mode 100644 index 000000000..8cdc10498 --- /dev/null +++ b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wfp8afp8 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml new file mode 100644 index 000000000..d766c9f53 --- /dev/null +++ b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wfp8afp8 diff --git a/benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml b/benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml new file mode 100644 index 000000000..90af4a558 --- /dev/null +++ b/benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint8 diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 
+enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml new file mode 100644 index 000000000..14024b565 --- /dev/null +++ b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint8 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml new file mode 100644 index 000000000..14024b565 --- /dev/null +++ b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint8 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml new file mode 100644 index 000000000..55a37e029 --- /dev/null +++ b/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml new file mode 100644 index 000000000..010dd3bc3 --- /dev/null +++ b/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +quantization: wint4 +enable_static_graph_inference: True diff --git a/benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml new file mode 100644 index 000000000..45ee7d14e --- /dev/null +++ b/benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml new file mode 100644 index 000000000..45ee7d14e --- /dev/null +++ b/benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml new file mode 100644 index 000000000..45ee7d14e --- /dev/null +++ b/benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml new file mode 100644 index 000000000..60a6dbeef --- /dev/null +++ b/benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +quantization: wint8 
+tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml new file mode 100644 index 000000000..60a6dbeef --- /dev/null +++ b/benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +quantization: wint8 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml new file mode 100644 index 000000000..60a6dbeef --- /dev/null +++ b/benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +quantization: wint8 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml b/benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml new file mode 100644 index 000000000..7a127995e --- /dev/null +++ b/benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 75 +gpu_memory_utilization: 0.85 +kv_cache_ratio: 0.75 +quantization: wint4 +tensor_parallel_size: 4 \ No newline at end of file diff --git a/benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml b/benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml new file mode 100644 index 000000000..4d6cff601 --- /dev/null +++ b/benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 25 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.75 +quantization: wint8 +tensor_parallel_size: 4 \ No newline at end of file diff --git a/benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml new file mode 100644 index 000000000..00fa7bef0 --- /dev/null +++ b/benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 50 +gpu_memory_utilization: 0.85 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml new file mode 100644 index 000000000..00fa7bef0 --- /dev/null +++ b/benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 50 +gpu_memory_utilization: 0.85 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml new file mode 100644 index 000000000..8ed7b40b3 --- /dev/null +++ b/benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 50 +gpu_memory_utilization: 0.8 +kv_cache_ratio: 0.75 +quantization: wint4 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml new file mode 100644 index 000000000..8ed7b40b3 --- /dev/null +++ b/benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 50 +gpu_memory_utilization: 0.8 +kv_cache_ratio: 0.75 +quantization: wint4 +tensor_parallel_size: 1 diff --git a/benchmarks/yaml/request_yaml/eb45-128k.yaml b/benchmarks/yaml/request_yaml/eb45-128k.yaml new file mode 100644 index 000000000..052d20997 --- /dev/null +++ b/benchmarks/yaml/request_yaml/eb45-128k.yaml @@ -0,0 +1,8 @@ +top_p: 0.8 +temperature: 0.8 +metadata: + min_tokens: 1 +max_tokens: 131071 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 0 diff 
--git a/benchmarks/yaml/request_yaml/eb45-32k.yaml b/benchmarks/yaml/request_yaml/eb45-32k.yaml new file mode 100644 index 000000000..07753d410 --- /dev/null +++ b/benchmarks/yaml/request_yaml/eb45-32k.yaml @@ -0,0 +1,8 @@ +top_p: 0.8 +temperature: 0.8 +metadata: + min_tokens: 1 +max_tokens: 12288 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 0 diff --git a/benchmarks/yaml/request_yaml/qwen2-32k.yaml b/benchmarks/yaml/request_yaml/qwen2-32k.yaml new file mode 100644 index 000000000..464277942 --- /dev/null +++ b/benchmarks/yaml/request_yaml/qwen2-32k.yaml @@ -0,0 +1,8 @@ +top_p: 0.8 +temperature: 0.7 +metadata: + min_tokens: 1 +max_tokens: 12288 +repetition_penalty: 1.05 +frequency_penalty: 0 +presence_penalty: 0 \ No newline at end of file diff --git a/benchmarks/yaml/request_yaml/qwen3-32k.yaml b/benchmarks/yaml/request_yaml/qwen3-32k.yaml new file mode 100644 index 000000000..8f1fc1fd7 --- /dev/null +++ b/benchmarks/yaml/request_yaml/qwen3-32k.yaml @@ -0,0 +1,8 @@ +top_p: 0.8 +temperature: 0.7 +metadata: + min_tokens: 1 +max_tokens: 12288 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 1.5 \ No newline at end of file diff --git a/benchmarks/yaml/request_yaml/x1-32k.yaml b/benchmarks/yaml/request_yaml/x1-32k.yaml new file mode 100644 index 000000000..7cec615c4 --- /dev/null +++ b/benchmarks/yaml/request_yaml/x1-32k.yaml @@ -0,0 +1,8 @@ +top_p: 0.95 +temperature: 0.6 +metadata: + min_tokens: 1 +max_tokens: 32767 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 0 diff --git a/benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml new file mode 100644 index 000000000..b2cbce4a6 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml @@ -0,0 +1,6 @@ +tensor_parallel_size: 8 +max_model_len: 32768 +max_num_seqs: 32 +num_gpu_blocks_override: 4096 +kv_cache_ratio: 0.5 +reasoning_parser: ernie-x1 diff --git a/benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml b/benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml new file mode 100644 index 000000000..f6b593889 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 32 +gpu_memory_utilization: 0.9 +tensor_parallel_size: 4 +quantization: wint4 +reasoning_parser: ernie-x1 diff --git a/benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml b/benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml new file mode 100644 index 000000000..25a2e89a2 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +tensor_parallel_size: 8 +quantization: wint4 +reasoning_parser: ernie-x1 diff --git a/benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml new file mode 100644 index 000000000..a6f522578 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml @@ -0,0 +1,10 @@ +enable_prefix_caching: True +num_gpu_blocks_override: 8000 +max_model_len: 32768 +max_num_seqs: 64 +gpu_memory_utilization: 0.85 +kv_cache_ratio: 0.5 +tensor_parallel_size: 8 +swap_space: 200 +cache_queue_port: 55664 +reasoning_parser: ernie-x1 diff --git a/benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml new file mode 100644 index 000000000..b2cbce4a6 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml @@ -0,0 +1,6 @@ +tensor_parallel_size: 8 +max_model_len: 32768 +max_num_seqs: 32 +num_gpu_blocks_override: 4096 +kv_cache_ratio: 0.5 +reasoning_parser: ernie-x1 diff --git 
a/benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml b/benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml new file mode 100644 index 000000000..df01844d1 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 8 +gpu_memory_utilization: 0.9 +tensor_parallel_size: 4 +quantization: wint8 +reasoning_parser: ernie-x1 diff --git a/benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml b/benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml new file mode 100644 index 000000000..376177602 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 64 +gpu_memory_utilization: 0.9 +tensor_parallel_size: 8 +quantization: wint8 +reasoning_parser: ernie-x1 \ No newline at end of file diff --git a/benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml new file mode 100644 index 000000000..a6f522578 --- /dev/null +++ b/benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml @@ -0,0 +1,10 @@ +enable_prefix_caching: True +num_gpu_blocks_override: 8000 +max_model_len: 32768 +max_num_seqs: 64 +gpu_memory_utilization: 0.85 +kv_cache_ratio: 0.5 +tensor_parallel_size: 8 +swap_space: 200 +cache_queue_port: 55664 +reasoning_parser: ernie-x1 diff --git a/build.sh b/build.sh index 8591a52f2..4e4098559 100644 --- a/build.sh +++ b/build.sh @@ -17,8 +17,9 @@ BUILD_WHEEL=${1:-1} PYTHON_VERSION=${2:-"python"} export python=$PYTHON_VERSION -CPU_USE_BF16=${3:-"false"} -BUILDING_ARCS=${4:-""} +FD_CPU_USE_BF16=${3:-"false"} +FD_BUILDING_ARCS=${4:-""} + # paddle distributed use to set archs unset PADDLE_CUDA_ARCH_LIST @@ -30,13 +31,9 @@ EGG_DIR="fastdeploy.egg-info" # custom_ops directory config OPS_SRC_DIR="custom_ops" -OPS_BUILD_DIR="build" -OPS_EGG_DIR="efficitentllm_ops.egg-info" OPS_TMP_DIR_BASE="tmp_base" OPS_TMP_DIR="tmp" -TEST_DIR="tests" - # command line log config RED='\033[0;31m' BLUE='\033[0;34m' @@ -44,13 +41,14 @@ GREEN='\033[1;32m' BOLD='\033[1m' NONE='\033[0m' +DEVICE_TYPE="gpu" function python_version_check() { PY_MAIN_VERSION=`${python} -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'` PY_SUB_VERSION=`${python} -V 2>&1 | awk '{print $2}' | awk -F '.' 
'{print $2}'` echo -e "find python version ${PY_MAIN_VERSION}.${PY_SUB_VERSION}" - if [ $PY_MAIN_VERSION -ne "3" -o $PY_SUB_VERSION -lt "8" ]; then - echo -e "${RED}FAIL:${NONE} please use Python >= 3.8" + if [ $PY_MAIN_VERSION -ne "3" -o $PY_SUB_VERSION -lt "9" ]; then + echo -e "${RED}FAIL:${NONE} please use Python >= 3.9" exit 1 fi } @@ -75,6 +73,7 @@ function copy_ops(){ WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg" is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"` if [ "$is_rocm" = "True" ]; then + DEVICE_TYPE="rocm" cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu echo -e "ROCM ops have been copy to fastdeploy" return @@ -82,6 +81,7 @@ function copy_ops(){ mkdir -p ../fastdeploy/model_executor/ops/base is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"` if [ "$is_cuda" = "True" ]; then + DEVICE_TYPE="gpu" cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu echo -e "BASE and CUDA ops have been copy to fastdeploy" @@ -90,6 +90,7 @@ function copy_ops(){ is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"` if [ "$is_xpu" = "True" ]; then + DEVICE_TYPE="xpu" cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/xpu echo -e "xpu ops have been copy to fastdeploy" return @@ -97,20 +98,14 @@ function copy_ops(){ is_npu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('npu'))"` if [ "$is_npu" = "True" ]; then + DEVICE_TYPE="npu" cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/npu echo -e "npu ops have been copy to fastdeploy" return fi + DEVICE_TYPE="cpu" cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base - cd ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/xFasterTransformer/build/ - for file in *_pd_.so; do - mv "$file" "${file/_pd_/}" - done - cd ../../x86-simd-sort/builddir/ - for file in *_pd_.so; do - mv "$file" "${file/_pd_/}" - done cd ../../../../ cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu echo -e "BASE and CPU ops have been copy to fastdeploy" @@ -122,15 +117,30 @@ function build_and_install_ops() { export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy} echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..." ${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE} + find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \; echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..." - if [ "$CPU_USE_BF16" == "true" ]; then - CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} - : - elif [ "$CPU_USE_BF16" == "false" ]; then + TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}` + is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"` + if [ "$is_xpu" = "True" ]; then + cd xpu_ops/src + bash build.sh ${TMP_DIR_REAL_PATH} + cd ../.. 
+ elif [ "$FD_CPU_USE_BF16" == "true" ]; then + if [ "$FD_BUILDING_ARCS" == "" ]; then + FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + else + FD_BUILDING_ARCS=${FD_BUILDING_ARCS} FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + fi + find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \; + elif [ "$FD_CPU_USE_BF16" == "false" ]; then + if [ "$FD_BUILDING_ARCS" == "" ]; then ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} - : + else + FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + fi + find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \; else - echo "Error: Invalid parameter '$CPU_USE_BF16'. Please use true or false." + echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false." exit 1 fi if [ $? -ne 0 ]; then @@ -146,11 +156,7 @@ function build_and_install_ops() { function build_and_install() { echo -e "${BLUE}[build]${NONE} building fastdeploy wheel..." - if [ "$BUILDING_ARCS" == "" ]; then - ${python} setup.py bdist_wheel --python-tag py3 - else - BUILDING_ARCS=${BUILDING_ARCS} ${python} setup.py bdist_wheel --python-tag py3 - fi + ${python} setup.py bdist_wheel --python-tag=py3 if [ $? -ne 0 ]; then echo -e "${RED}[FAIL]${NONE} build fastdeploy wheel failed" @@ -174,10 +180,12 @@ function cleanup() { rm -rf $BUILD_DIR $EGG_DIR if [ `${python} -m pip list | grep fastdeploy | wc -l` -gt 0 ]; then echo -e "${BLUE}[init]${NONE} uninstalling fastdeploy..." - ${python} -m pip uninstall -y fastdeploy + ${python} -m pip uninstall -y fastdeploy-${DEVICE_TYPE} fi rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR + rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE + rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR } function abort() { @@ -187,7 +195,7 @@ function abort() { cur_dir=`basename "$pwd"` rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR - ${python} -m pip uninstall -y fastdeploy + ${python} -m pip uninstall -y fastdeploy-${DEVICE_TYPE} rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR } diff --git a/custom_ops/0001-DeepGEMM-95e81b3.patch b/custom_ops/0001-DeepGEMM-95e81b3.patch new file mode 100644 index 000000000..e62972cec --- /dev/null +++ b/custom_ops/0001-DeepGEMM-95e81b3.patch @@ -0,0 +1,643 @@ +From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001 +From: minghaipeng +Date: Wed, 25 Jun 2025 15:05:24 +0800 +Subject: [PATCH] DeepGEMM 95e81b3 + +--- + deep_gemm/__init__.py | 2 +- + deep_gemm/include/deep_gemm/scheduler.cuh | 2 +- + deep_gemm/jit/compiler.py | 2 +- + deep_gemm/jit/interleave_ffma.py | 2 +- + deep_gemm/jit/runtime.py | 4 +- + deep_gemm/jit/template.py | 34 ++++---- + deep_gemm/jit_kernels/gemm.py | 44 +++++------ + deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------ + deep_gemm/jit_kernels/tuner.py | 10 +-- + deep_gemm/jit_kernels/utils.py | 18 +++-- + deep_gemm/paddle_utils.py | 20 +++++ + deep_gemm/utils.py | 30 +++---- + 12 files changed, 143 insertions(+), 121 deletions(-) + create mode 100644 deep_gemm/paddle_utils.py + +diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py +index 15b22ca..63e7fb7 100644 +--- a/deep_gemm/__init__.py ++++ b/deep_gemm/__init__.py +@@ -1,4 +1,4 @@ +-import torch ++import paddle + + from . 
import jit + from .jit_kernels import ( +diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh +index 9743871..6c97152 100644 +--- a/deep_gemm/include/deep_gemm/scheduler.cuh ++++ b/deep_gemm/include/deep_gemm/scheduler.cuh +@@ -102,7 +102,7 @@ struct Scheduler { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::GroupedContiguous) { +- auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M); ++ auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)); + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::GroupedMasked) { + return curr_group_idx * shape_dim + block_idx * block_size; +diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py +index c17d466..6fdc52f 100644 +--- a/deep_gemm/jit/compiler.py ++++ b/deep_gemm/jit/compiler.py +@@ -4,7 +4,7 @@ import os + import re + import subprocess + import uuid +-from torch.utils.cpp_extension import CUDA_HOME ++from ..paddle_utils import CUDA_HOME + from typing import Tuple + + from . import interleave_ffma +diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py +index fcb377e..db9d6f3 100644 +--- a/deep_gemm/jit/interleave_ffma.py ++++ b/deep_gemm/jit/interleave_ffma.py +@@ -3,7 +3,7 @@ import mmap + import os + import re + import subprocess +-from torch.utils.cpp_extension import CUDA_HOME ++from ..paddle_utils import CUDA_HOME + + + def run_cuobjdump(file_path): +diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py +index 66c370a..4761426 100644 +--- a/deep_gemm/jit/runtime.py ++++ b/deep_gemm/jit/runtime.py +@@ -1,6 +1,6 @@ + import ctypes + import os +-import torch ++import paddle + from typing import Optional + + from .template import map_ctype +@@ -35,7 +35,7 @@ class Runtime: + assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}' + cargs = [] + for arg, (name, dtype) in zip(args, self.args): +- if isinstance(arg, torch.Tensor): ++ if isinstance(arg, paddle.Tensor): + assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`' + else: + assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`' +diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py +index ead37f5..51b02c1 100644 +--- a/deep_gemm/jit/template.py ++++ b/deep_gemm/jit/template.py +@@ -1,24 +1,24 @@ + import copy + import ctypes + import os +-import torch ++import paddle + from typing import Any, Dict, Iterable, Tuple + + + # Name map for Python `eval` + typename_map: Dict[Any, str] = { + **{t: t.__name__ for t in (bool, int, float)}, +- torch.int: 'torch.int', +- torch.float: 'torch.float', +- torch.bfloat16: 'torch.bfloat16', +- torch.float8_e4m3fn: 'torch.float8_e4m3fn', +- torch.cuda.Stream: 'torch.cuda.Stream', ++ paddle.int32: 'paddle.int32', ++ paddle.float32: 'paddle.float32', ++ paddle.bfloat16: 'paddle.bfloat16', ++ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn', ++ paddle.device.cuda.Stream: "paddle.device.cuda.Stream", + } + + # `ctype` map for Python casting + ctype_map: Dict[Any, Any] = { + **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)}, +- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)}, ++ **{t: ctypes.c_void_p for t in (paddle.int32, 
paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)}, + } + + +@@ -27,25 +27,25 @@ genc_map = { + bool: ('bool', 'bool'), + int: ('int', 'int'), + float: ('float', 'float'), +- torch.int: ('void*', 'int*'), +- torch.float: ('void*', 'float*'), +- torch.bfloat16: ('void*', '__nv_bfloat16*'), +- torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), +- torch.cuda.Stream: ('void*', 'cudaStream_t'), ++ paddle.int32: ('void*', 'int*'), ++ paddle.float32: ('void*', 'float*'), ++ paddle.bfloat16: ('void*', '__nv_bfloat16*'), ++ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'), ++ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'), + } + + + def map_ctype(value: Any) -> Any: + if hasattr(value, 'data_ptr'): +- if value.dtype == torch.int: ++ if value.dtype == paddle.int32: + return ctypes.c_void_p(value.data_ptr()) +- elif value.dtype == torch.float: ++ elif value.dtype == paddle.float32: + return ctypes.c_void_p(value.data_ptr()) +- elif value.dtype == torch.bfloat16: ++ elif value.dtype == paddle.bfloat16: + return ctypes.c_void_p(value.data_ptr()) +- elif value.dtype == torch.float16: ++ elif value.dtype == paddle.float16: + return ctypes.c_void_p(value.data_ptr()) +- elif value.dtype == torch.float8_e4m3fn: ++ elif value.dtype == paddle.float8_e4m3fn: + return ctypes.c_void_p(value.data_ptr()) + else: + return ctypes.c_void_p(value.data_ptr()) +diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py +index cb438b7..44aa0ed 100644 +--- a/deep_gemm/jit_kernels/gemm.py ++++ b/deep_gemm/jit_kernels/gemm.py +@@ -1,5 +1,5 @@ + import math +-import torch ++import paddle + from functools import lru_cache + from typing import Tuple + +@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, + return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config + + +-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], +- rhs: Tuple[torch.Tensor, torch.Tensor], +- out: torch.Tensor) -> None: ++def gemm_fp8_fp8_bf16_nt(lhs: Tuple[paddle.Tensor, paddle.Tensor], ++ rhs: Tuple[paddle.Tensor, paddle.Tensor], ++ out: paddle.Tensor) -> None: + """ + Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, +- this function will do a transposing with a set of slow PyTorch operations. ++ this function will do a transposing with a set of slow paddle operations. + + Arguments: +- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`, ++ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`. +- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`. ++ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`. + the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`. + out: the BF16 output tensor of shape `[m, n]`, representing the result. 
+ """ +@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], + n, k_ = rhs.shape + m_, n_ = out.shape + +- assert n % 64 == 0 and k % 128 == 0 ++ # assert n % 64 == 0 and k % 128 == 0 + + # Type and shape checks +- assert m == m_ and n == n_ and k == k_ +- assert n > 0 and k > 0 +- assert lhs_scales.shape == (m, (k + 127) // 128) +- assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) +- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 +- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 +- assert out.dtype == torch.bfloat16 +- assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() ++ # assert m == m_ and n == n_ and k == k_ ++ # assert n > 0 and k > 0 ++ # assert lhs_scales.shape == (m, (k + 127) // 128) ++ # assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128) ++ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32 ++ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 ++ # assert out.dtype == paddle.bfloat16 ++ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) +- assert rhs_scales.is_contiguous() ++ # assert rhs_scales.is_contiguous() + + # Do nothing if `m` is zero + if m == 0: +@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], + global includes, template + num_sms = get_num_sms() + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms) +- args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0]) ++ args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_config[0]) + runtime = jit_tuner.compile_and_tune( + name='gemm_fp8_fp8_bf16_nt', + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, +@@ -225,10 +225,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]}, + space=(), + includes=includes, +- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), +- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), +- ('out', torch.bfloat16), ('m', int), +- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), ++ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32), ++ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32), ++ ('out', paddle.bfloat16), ('m', int), ++ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) +diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py +index 3b518c9..ba776bd 100644 +--- a/deep_gemm/jit_kernels/m_grouped_gemm.py ++++ b/deep_gemm/jit_kernels/m_grouped_gemm.py +@@ -1,4 +1,4 @@ +-import torch ++import paddle + from typing import Tuple + + from .gemm import get_best_configs, get_block_n_padding_for_smem_d +@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout, + """ + + +-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], +- rhs: Tuple[torch.Tensor, torch.Tensor], +- out: torch.Tensor, m_indices: torch.Tensor) -> None: ++def 
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[paddle.Tensor, paddle.Tensor], ++ rhs: Tuple[paddle.Tensor, paddle.Tensor], ++ out: paddle.Tensor, m_indices: paddle.Tensor) -> None: + """ + Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, +- this function will do a transposing with a set of slow PyTorch operations. ++ this function will do a transposing with a set of slow Pypaddle operations. + On the M axis, inputs are grouped into several batches, of which batch sizes aligned to + `get_m_alignment_for_contiguous_layout()` (128). + + Arguments: +- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`, ++ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`. +- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. ++ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`. + the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. + out: the BF16 output tensor of shape `[m_sum, n]`, representing the result. +- m_indices: a tensor of shape `[m_sum]` with type `torch.int`. ++ m_indices: a tensor of shape `[m_sum]` with type `paddle.int`. + `m_indices[i]` records the group which the i-th row of the LHS belong to, + which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. + Values of `m_indices` in every-m-alignment-block must also be the same. 
+@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten + m__ = m_indices.numel() + + # Type and shape checks +- assert m == m_ == m__ and k == k_ and n == n_ +- assert lhs_scales.shape == (m, (k + 127) // 128) +- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) +- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 +- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 +- assert out.dtype == torch.bfloat16 +- assert m_indices.dtype == torch.int32 +- assert lhs.is_contiguous() and rhs.is_contiguous() +- assert out.is_contiguous() and m_indices.is_contiguous() ++ # assert m == m_ == m__ and k == k_ and n == n_ ++ # assert lhs_scales.shape == (m, (k + 127) // 128) ++ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) ++ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32 ++ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 ++ # assert out.dtype == paddle.bfloat16 ++ # assert m_indices.dtype == paddle.int32 ++ # assert lhs.is_contiguous() and rhs.is_contiguous() ++ # assert out.is_contiguous() and m_indices.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) +- assert rhs_scales.is_contiguous() ++ # assert rhs_scales.is_contiguous() + + # Do nothing if `m` is zero + if m == 0: +@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True) + args = (lhs, lhs_scales, rhs, rhs_scales, out, + m_indices, m, num_groups, +- torch.cuda.current_stream(), num_sms, smem_config[0]) ++ paddle.device.cuda.current_stream(), num_sms, smem_config[0]) + runtime = jit_tuner.compile_and_tune( + name='m_grouped_gemm_fp8_fp8_bf16_nt', + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, +@@ -105,11 +105,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten + 'GEMM_TYPE': 'GroupedContiguous'}, + space=(), + includes=includes, +- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), +- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), +- ('out', torch.bfloat16), +- ('grouped_layout', torch.int32), ('m', int), ('num_groups', int), +- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), ++ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32), ++ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32), ++ ('out', paddle.bfloat16), ++ ('grouped_layout', paddle.int32), ('m', int), ('num_groups', int), ++ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) +@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten + runtime(*args) + + +-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor], +- rhs: Tuple[torch.Tensor, torch.Tensor], +- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None: ++def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[paddle.Tensor, paddle.Tensor], ++ rhs: Tuple[paddle.Tensor, paddle.Tensor], ++ out: paddle.Tensor, masked_m: paddle.Tensor, expected_m: int) -> None: + """ + Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling 
and 128x128 RHS scaling. + LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format. + RHS and RHS scaling factors are required to be transposed. + The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement, +- this function will do a transposing with a set of slow PyTorch operations. ++ this function will do a transposing with a set of slow paddle operations. + Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch + should be separately transposed. + + Arguments: +- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, ++ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`, + the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`. +- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`. ++ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`. + the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. + out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. + masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute +@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] + num_groups___ = masked_m.numel() + + # Type and shape checks +- assert num_groups == num_groups_ == num_groups__ == num_groups___ +- assert m == m_ and n == n_ and k == k_ +- assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 +- assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) +- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) +- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 +- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 +- assert out.dtype == torch.bfloat16 +- assert masked_m.dtype == torch.int32 +- assert lhs.is_contiguous() and rhs.is_contiguous() +- assert out.is_contiguous() and masked_m.is_contiguous() ++ # assert num_groups == num_groups_ == num_groups__ == num_groups___ ++ # assert m == m_ and n == n_ and k == k_ ++ # assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 ++ # assert lhs_scales.shape == (num_groups, m, (k + 127) // 128) ++ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128) ++ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32 ++ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32 ++ # assert out.dtype == paddle.bfloat16 ++ # assert masked_m.dtype == paddle.int32 ++ # assert lhs.is_contiguous() and rhs.is_contiguous() ++ # assert out.is_contiguous() and masked_m.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) +- assert rhs_scales.is_contiguous() ++ # assert rhs_scales.is_contiguous() + + # Auto-tuning with compilation + global includes, template +@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] + + args = (lhs, lhs_scales, rhs, rhs_scales, out, + masked_m, m, +- torch.cuda.current_stream(), num_sms, smem_config[0]) ++ paddle.device.cuda.current_stream(), num_sms, 
smem_config[0]) + runtime = jit_tuner.compile_and_tune( + name='m_grouped_gemm_fp8_fp8_bf16_nt', + keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, +@@ -189,11 +189,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] + 'GEMM_TYPE': 'GroupedMasked'}, + space=(), + includes=includes, +- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float), +- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float), +- ('out', torch.bfloat16), +- ('grouped_layout', torch.int32), ('m', int), +- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)), ++ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32), ++ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32), ++ ('out', paddle.bfloat16), ++ ('grouped_layout', paddle.int32), ('m', int), ++ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)), + template=template, + args=args + ) +diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py +index 6ed6749..9e1d70f 100644 +--- a/deep_gemm/jit_kernels/tuner.py ++++ b/deep_gemm/jit_kernels/tuner.py +@@ -1,6 +1,6 @@ + import copy + import os +-import torch ++import paddle + from typing import Any, Dict + + from ..jit import build, cpp_format, generate, Runtime +@@ -51,10 +51,10 @@ class JITTuner: + continue + + # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels +- start_event = torch.cuda.Event(enable_timing=True) +- end_event = torch.cuda.Event(enable_timing=True) +- torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() +- torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') ++ start_event = paddle.device.cuda.Event(enable_timing=True) ++ end_event = paddle.device.cuda.Event(enable_timing=True) ++ paddle.empty((int(256e6 // 4)), dtype=paddle.int32).zero_() ++ paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32) + start_event.record() + for i in range(20): + assert runtime(*args) == 0 +diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py +index c6da56b..a17b1b1 100644 +--- a/deep_gemm/jit_kernels/utils.py ++++ b/deep_gemm/jit_kernels/utils.py +@@ -1,4 +1,4 @@ +-import torch ++import paddle + + _num_sms = None + +@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None: + num_sms: the desired maximum SM count for all GEMM kernels to use. + """ + global _num_sms +- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count ++ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count + _num_sms = num_sms + + +@@ -25,7 +25,7 @@ def get_num_sms() -> int: + """ + global _num_sms + if _num_sms is None: +- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count ++ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count + return _num_sms + + +@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int: + return ceil_div(x, alignment) * alignment + + +-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: ++def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor: + """ +- Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. ++ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary. 
+ If the input tensor is already column-major layout and 16-byte aligned along the M axis + (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. + +@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: +- if x.stride(0) == 1 and x.stride(1) == aligned_m: ++ if x.strides[0] == 1 and x.strides[1] == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout +- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: ++ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing +- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) ++ aligned_x = paddle.transpose( ++ paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1] ++ ) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x +diff --git a/deep_gemm/paddle_utils.py b/deep_gemm/paddle_utils.py +new file mode 100644 +index 0000000..2326807 +--- /dev/null ++++ b/deep_gemm/paddle_utils.py +@@ -0,0 +1,20 @@ ++import os ++ ++def get_cuda_home(): ++ """Get Cuda home directory""" ++ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") ++ if cuda_home: ++ return cuda_home ++ ++ try: ++ which_cmd = "which nvcc" ++ ++ nvcc_path = os.popen(which_cmd).read().strip() ++ if nvcc_path: ++ return os.path.dirname(os.path.dirname(nvcc_path)) ++ except Exception: ++ pass ++ ++ return None ++ ++CUDA_HOME = get_cuda_home() +\ No newline at end of file +diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py +index d5cdd01..5237f09 100644 +--- a/deep_gemm/utils.py ++++ b/deep_gemm/utils.py +@@ -1,15 +1,15 @@ + import os + import sys + import time +-import torch +-import torch.distributed as dist ++import paddle ++import paddle.distributed as dist + + + def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data +- torch.cuda.synchronize() +- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') ++ paddle.device.cuda.synchronize() ++ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32) + cache.zero_() + + # Warmup +@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10, + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: +- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') +- y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') ++ x = paddle.randn((8192, 8192), dtype=paddle.float32) ++ y = paddle.randn((8192, 8192), dtype=paddle.float32) + x @ y + + # Testing +- start_event = torch.cuda.Event(enable_timing=True) +- end_event = torch.cuda.Event(enable_timing=True) ++ start_event = paddle.device.cuda.Event(enable_timing=True) ++ end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() +- torch.cuda.synchronize() ++ paddle.device.synchronize() + + return start_event.elapsed_time(end_event) / num_tests + +@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output and not 
using_nsys else empty_suppress + with suppress(): +- schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None +- profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() ++ scheduler = paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None ++ profiler = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CPU, paddle.profiler.ProfilerTarget.GPU], scheduler=scheduler) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: +- lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') +- rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') ++ lhs = paddle.randn((8192, 8192), dtype=paddle.float32) ++ rhs = paddle.randn((8192, 8192), dtype=paddle.float32) + lhs @ rhs +- dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) ++ dist.all_reduce(paddle.ones(1, dtype=paddle.float32)) + for _ in range(num_tests): + if sleep_between_tests > 0.0: + time.sleep(sleep_between_tests) + if flush_l2: +- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() ++ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_() + fn() + + if not using_nsys: +-- +2.43.0 + diff --git a/custom_ops/cpu_ops/avx_weight_only.cc b/custom_ops/cpu_ops/avx_weight_only.cc deleted file mode 100644 index 1d410156e..000000000 --- a/custom_ops/cpu_ops/avx_weight_only.cc +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "dtype.h" -#include "matmul_helper.h" -#include "my_types.h" -#include "paddle/extension.h" -#include "paddle/phi/core/kernel_registry.h" -template -void AvxCompute(const paddle::Tensor &x, - const paddle::Tensor &weight, - const paddle::Tensor &w_bias, - bool trans, - const std::string alog, - paddle::Tensor &out, - xft::Matrix &quantizedWeight, - xft::Vector &WeightScale, - xft::Vector &WeightZero, - xft::Vector &WeightSum, - MMHelper *mmHelper) { - auto out_data = out.data(); - const float *x_data = reinterpret_cast(x.data()); - const float *bias_data = nullptr; - if (w_bias.initialized()) { - bias_data = reinterpret_cast(w_bias.data()); - } - int m = 1; - for (int i = 0; i < x.shape().size() - 1; i++) { - m = m * x.shape()[i]; - } - int k = x.shape()[x.shape().size() - 1]; - int l = weight.shape()[1]; - int n = weight.shape()[1]; - if (w_bias.initialized()) { - mmHelper->compute_bias(false, - m, - n, - k, - 1.0f, - x_data, - k, - quantizedWeight.Data(), - WeightScale.Data(), - WeightZero.Data(), - WeightSum.Data(), - 0.0f, - out_data, - l, - bias_data); - } else { - mmHelper->compute(false, - m, - n, - k, - 1.0f, - x_data, - k, - quantizedWeight.Data(), - WeightScale.Data(), - WeightZero.Data(), - WeightSum.Data(), - 0.0, - out_data, - l); - } -}; -template -void AvxWeightOnly(const paddle::Tensor &x, - const paddle::Tensor &weight, - const paddle::Tensor &w_bias, - bool trans, - const std::string alog, - paddle::Tensor &out) { - static std::unordered_map *, - xft::Vector *, - xft::Vector *, - xft::Vector *>> - weight_only_hub; - std::stringstream weights_addr; - weights_addr << weight.data() << alog; - std::string weight_only_key = weights_addr.str(); - auto it_created = weight_only_hub.find(weight_only_key); - static MMHelper *mmHelper; - int rows = weight.shape()[0], cols = weight.shape()[1]; - xft::Vector *WeightScale = - new xft::Vector(); // if weight is int8 - xft::Vector *WeightZero = - new xft::Vector(); // if weight is int8 - xft::Vector *WeightSum = - new xft::Vector(); // if weight is int8 - xft::Matrix *quantizedWeight = new xft::Matrix(); - if (it_created == weight_only_hub.end()) { - auto weight_ptr = reinterpret_cast(weight.data()); - xft::Matrix convertedWeight; - mmHelper = new MMHelper(xft::DeviceKind::iCPU, 0); - mmHelper->convertWeight(trans, - rows, - cols, - weight_ptr, - nullptr, - nullptr, - convertedWeight, - *WeightScale, - *WeightZero, - *WeightSum); - quantizedWeight->Resize(rows, cols); - mmHelper->packWeight(trans, convertedWeight, *quantizedWeight); - weight_only_hub[weight_only_key] = std::make_tuple( - quantizedWeight, WeightScale, WeightZero, WeightSum); - AvxCompute(x, - weight, - w_bias, - trans, - alog, - out, - *quantizedWeight, - *WeightScale, - *WeightZero, - *WeightSum, - mmHelper); - } else { - AvxCompute(x, - weight, - w_bias, - trans, - alog, - out, - *(std::get<0>(it_created->second)), - *(std::get<1>(it_created->second)), - *(std::get<2>(it_created->second)), - *(std::get<3>(it_created->second)), - mmHelper); - } -} -std::vector InvokeAvxWeightOnly(const paddle::Tensor &x, - const paddle::Tensor &weight, - const paddle::Tensor &w_bias, - const std::string &alog, - bool trans) { - auto out_shape = x.shape(); - out_shape[out_shape.size() - 1] = weight.shape()[1]; - auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace()); - if (alog == "int8") { - AvxWeightOnly(x, weight, w_bias, trans, alog, out); - } else if (alog == "fp16") { - AvxWeightOnly(x, weight, w_bias, trans, alog, out); - } else { - AvxWeightOnly(x, weight, 
w_bias, trans, alog, out); - } - return {out}; -} - -std::vector> AvxWeightOnlyInferShape( - std::vector x_shape, - std::vector weigh_shape, - std::vector weigh_bias_shape) { - int m = 1; - for (int i = 0; i < x_shape.size() - 1; i++) { - m = m * x_shape[i]; - } - return {std::vector{m, weigh_shape[1]}}; -} - -std::vector AvxWeightOnlyInferDtype( - paddle::DataType x_dtype, - paddle::DataType weight_dtype, - paddle::DataType weight_bias_dtype) { - return {x_dtype}; -} - -PD_BUILD_STATIC_OP(avx_weight_only) - .Inputs({"x", "weight", "w_bias"}) - .Outputs({"out"}) - .Attrs({"alog: std::string", "trans:bool"}) - .SetKernelFn(PD_KERNEL(InvokeAvxWeightOnly)) - .SetInferShapeFn(PD_INFER_SHAPE(AvxWeightOnlyInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(AvxWeightOnlyInferDtype)); diff --git a/custom_ops/cpu_ops/rebuild_padding.cc b/custom_ops/cpu_ops/rebuild_padding.cc new file mode 100644 index 000000000..8ce533d04 --- /dev/null +++ b/custom_ops/cpu_ops/rebuild_padding.cc @@ -0,0 +1,268 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +void RebuildPaddingCPUImpl(T *output_data, + const T *input_data, + const int *cum_offsets_data, + const int *seq_len_this_time_data, + const int *seq_lens_decoder_data, + const int *seq_lens_encoder_data, + int max_input_length, + int dim_embed, + const int elem_nums) { + for (int i = 0; i < elem_nums; ++i) { + const int bi = i / dim_embed; + const int bias_idx = i % dim_embed; + int seq_id = 0; + + if (seq_len_this_time_data[bi] == 0) { + continue; + } + if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { + continue; + } + if (seq_lens_encoder_data[bi] > 0) { + seq_id = seq_lens_encoder_data[bi] - 1; + } + const int ori_token_idx = + bi * max_input_length - cum_offsets_data[bi] + seq_id; + const int src_offset = ori_token_idx * dim_embed + bias_idx; + + output_data[i] = input_data[src_offset]; + } +} + +template +void RebuildAppendPaddingCPUImpl(T *output_data, + const T *input_data, + const int *cum_offsets_data, + const int *seq_len_this_time_data, + const int *seq_lens_decoder_data, + const int *seq_lens_encoder_data, + const int *output_padding_offset_data, + const int max_input_length, + const int dim_embed, + const int64_t output_elem_nums) { + for (int i = 0; i < output_elem_nums; ++i) { + int out_token_id = i / dim_embed; + int ori_token_id = + out_token_id + output_padding_offset_data[out_token_id]; + int bi = ori_token_id / max_input_length; + if (seq_len_this_time_data[bi] == 0 || + (seq_lens_decoder_data[bi] == 0 && + seq_lens_encoder_data[bi] == 0)) { + continue; + } + int seq_id = 0; + if (seq_lens_encoder_data[bi] > 0) { + seq_id = seq_lens_encoder_data[bi] - 1; + } + int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id; + int bias_idx = i % dim_embed; + int src_offset = input_token_id * 
dim_embed + bias_idx; + output_data[i] = input_data[src_offset]; + } +} + +std::vector RebuildPaddingCPU( + const paddle::Tensor &tmp_out, + const paddle::Tensor &cum_offsets, + const paddle::Tensor &seq_len_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::optional &output_padding_offset, + int max_input_length) { + auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); + auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true); + auto seq_len_this_time_cpu = + seq_len_this_time.copy_to(paddle::CPUPlace(), true); + auto seq_lens_decoder_cpu = + seq_lens_decoder.copy_to(paddle::CPUPlace(), true); + auto seq_lens_encoder_cpu = + seq_lens_encoder.copy_to(paddle::CPUPlace(), true); + paddle::optional output_padding_offset_cpu; + if (output_padding_offset) { + output_padding_offset_cpu = + output_padding_offset->copy_to(paddle::CPUPlace(), true); + } + + int token_num = tmp_out_cpu.shape()[0]; + int dim_embed = tmp_out_cpu.shape()[1]; + int bsz = cum_offsets_cpu.shape()[0]; + + paddle::Tensor out; + if (output_padding_offset_cpu) { + int need_delete_token_num = 0; + for (int i = 0; i < bsz; ++i) { + if (seq_lens_encoder_cpu.data()[i] > 0) { + need_delete_token_num += + seq_lens_encoder_cpu.data()[i] - 1; + } + } + int output_token_num = token_num - need_delete_token_num; + out = paddle::full({output_token_num, dim_embed}, + 0, + tmp_out_cpu.dtype(), + paddle::CPUPlace()); + } else { + out = paddle::full( + {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); + } + + const int *cum_offsets_data = cum_offsets_cpu.data(); + const int *seq_len_this_time_data = seq_len_this_time_cpu.data(); + const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data(); + const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data(); + int elem_nums = out.numel(); + + if (output_padding_offset_cpu) { + const int *output_padding_offset_data = + output_padding_offset_cpu->data(); + switch (tmp_out_cpu.dtype()) { + case paddle::DataType::FLOAT32: + RebuildAppendPaddingCPUImpl(out.data(), + tmp_out_cpu.data(), + cum_offsets_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::FLOAT16: + RebuildAppendPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cum_offsets_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::BFLOAT16: + RebuildAppendPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cum_offsets_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + default: + PD_THROW( + "Unsupported data type for rebuild_padding_cpu. 
" + "Only float32, float16, and bfloat16 are supported."); + } + } else { + switch (tmp_out_cpu.dtype()) { + case paddle::DataType::FLOAT32: + RebuildPaddingCPUImpl(out.data(), + tmp_out_cpu.data(), + cum_offsets_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::FLOAT16: + RebuildPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cum_offsets_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::BFLOAT16: + + RebuildPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cum_offsets_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + default: + PD_THROW( + "Unsupported data type for rebuild_padding_cpu. " + "Only float32, float16, and bfloat16 are supported."); + } + } + return {out}; +} + +std::vector> RebuildPaddingInferShape( + const std::vector &tmp_out_shape, + const std::vector &cum_offsets_shape, + const std::vector &seq_len_this_time_shape, + const std::vector &seq_lens_decoder_shape, + const std::vector &seq_lens_encoder_shape, + const paddle::optional> &output_padding_offset_shape) { + int64_t dim_embed = tmp_out_shape[1]; + if (output_padding_offset_shape) { + return {{-1, dim_embed}}; + } else { + int64_t bsz = cum_offsets_shape[0]; + return {{bsz, dim_embed}}; + } +} + +std::vector RebuildPaddingInferDtype( + const paddle::DataType &tmp_out_dtype, + const paddle::DataType &cum_offsets_dtype, + const paddle::DataType &seq_len_this_time_dtype, + const paddle::DataType &seq_lens_decoder_dtype, + const paddle::DataType &seq_lens_encoder_dtype, + const paddle::optional &output_padding_offset_dtype) { + return {tmp_out_dtype}; +} + +PD_BUILD_STATIC_OP(rebuild_padding_cpu) + .Inputs({"tmp_out", + "cum_offsets", + "seq_len_this_time", + "seq_lens_decoder", + "seq_lens_encoder", + paddle::Optional("output_padding_offset")}) + .Outputs({"out"}) + .Attrs({"max_input_length: int"}) + .SetKernelFn(PD_KERNEL(RebuildPaddingCPU)) + .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype)); diff --git a/custom_ops/cpu_ops/xft_all_layer.cc b/custom_ops/cpu_ops/xft_all_layer.cc deleted file mode 100644 index 7b24e0b8e..000000000 --- a/custom_ops/cpu_ops/xft_all_layer.cc +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "layers_decoder.h" -#include "paddle/extension.h" -#include "paddle/phi/core/kernel_registry.h" - -std::vector InvokeAllLLaMALayer( - const paddle::Tensor &input, - const std::vector &ln1Gamma, - const std::vector &ln1Beta, - const std::vector &qkvWeight, - const std::vector &qkvBiasWeight, - const std::vector &attnOutWeight, - const std::vector &attnOutBias, - const std::vector &ln2Gamma, - const std::vector &ln2Beta, - const std::vector &gateWeight, - const std::vector &gateBias, - const std::vector &upWeight, - const std::vector &upBias, - const std::vector &downWeight, - const std::vector &downBias, - const paddle::Tensor &pastSeqLen, - const paddle::Tensor ¤tSeqLen, - const paddle::Tensor &step, - int hiddensize, - int totalLayer, - const std::string &computeType, - const std::string &activation, - const std::string &normType, - int attHeadDim, - int attHeadNum, - int kvHeadNum, - int maxPositions, - int maxPosEmbed, - int intermediateSize) { - auto out = paddle::empty_like(input); - auto batchSize = input.shape()[0]; - auto inputSeqLen = input.shape()[1]; - auto past_seq_len = pastSeqLen.data()[0]; - auto cur_seq_len = static_cast(currentSeqLen.data()[0]); - auto step_id = step.data()[0]; - auto output_ptr = reinterpret_cast(out.data()); - auto xft_data_type = xft::DataType::fp16; - if (computeType == "bf16") { - xft_data_type = xft::DataType::bf16; - } else if (computeType == "bf16_int8") { - xft_data_type = xft::DataType::bf16_int8; - } - auto xft_act_type = xft::ActivationType::SILU; - if (activation == "relu") { - xft_act_type = xft::ActivationType::RELU; - } else if (activation == "gelu") { - xft_act_type = xft::ActivationType::GELU; - } else if (activation == "swiglu") { - xft_act_type = xft::ActivationType::SWIGLU; - } - auto xft_norm_type = xft::NormType::RMS; - if (normType == "layernorm") { - xft_norm_type = xft::NormType::LN; - } - auto input_ptr = reinterpret_cast(input.data()); - for (int i = 0; i < totalLayer; ++i) { - auto ln1Gamma_ptr = - reinterpret_cast(ln1Gamma[i].data()); - auto ln1Beta_ptr = - reinterpret_cast(ln1Beta[i].data()); - auto qkvWeight_ptr = - reinterpret_cast(qkvWeight[i].data()); - auto qkvBiasWeight_ptr = - reinterpret_cast(qkvBiasWeight[i].data()); - auto attnOutWeight_ptr = - reinterpret_cast(attnOutWeight[i].data()); - auto ln2Gamma_ptr = - reinterpret_cast(ln2Gamma[i].data()); - auto ln2Beta_ptr = - reinterpret_cast(ln2Beta[i].data()); - auto gate_weight_ptr = - reinterpret_cast(gateWeight[i].data()); - auto up_weight_ptr = - reinterpret_cast(upWeight[i].data()); - auto down_weight_ptr = - reinterpret_cast(downWeight[i].data()); - auto gate_bias_ptr = - reinterpret_cast(gateBias[i].data()); - auto up_bias_ptr = - reinterpret_cast(upBias[i].data()); - auto down_bias_ptr = - reinterpret_cast(downBias[i].data()); - auto attnOutBias_ptr = - reinterpret_cast(attnOutBias[i].data()); - invokeLayerLLaMA( - xft_data_type, // dt - xft_act_type, // at - xft_norm_type, // nt - i, // layerId - totalLayer, // totalLayers - batchSize, // batchSize - inputSeqLen, // inputSeqLen - attHeadDim, // attHeadDim - attHeadNum, // attHeadNum - kvHeadNum, // kvHeadNum - maxPositions, // maxPositions - maxPosEmbed, // maxPosEmbed - past_seq_len, // pastSeqLen - cur_seq_len, // currentSeqLen - step_id, // step - hiddensize, // hiddenSize - intermediateSize, // intermediateSize - reinterpret_cast(output_ptr), // output - hiddensize, // outputStride - input_ptr, // input - hiddensize, // inputStride - ln1Gamma_ptr, // ln1Gamma - ln1Beta_ptr, // ln1Beta - 
qkvWeight_ptr, // queryWeight - qkvWeight_ptr + hiddensize, // keyWeight - qkvWeight_ptr + hiddensize + kvHeadNum * attHeadDim, // valueWeight - attnOutWeight_ptr, // attnOutWeight - ln2Gamma_ptr, // ln2Gamma - ln2Beta_ptr, // ln2Beta - gate_weight_ptr, - up_weight_ptr, - down_weight_ptr, - qkvBiasWeight_ptr, // queryBias - qkvBiasWeight_ptr + hiddensize, // keyBias - qkvBiasWeight_ptr + hiddensize + - kvHeadNum * attHeadDim, // valueBias - attnOutBias_ptr, // attnOutBias - qkvWeight_ptr, // myqkvWeight - gate_bias_ptr, - up_bias_ptr, - down_bias_ptr, - qkvBiasWeight_ptr); - if (i < totalLayer - 1) { - memcpy(const_cast(input_ptr), - output_ptr, - batchSize * inputSeqLen * hiddensize * sizeof(float)); - } - } - return {out}; -} - -std::vector> AllLLaMALayerInferShape( - std::vector x_shape) { - return {x_shape}; -} - -std::vector AllLLaMALayerInferDtype( - paddle::DataType x_dtype) { - return {x_dtype}; -} - -PD_BUILD_STATIC_OP(xft_llama_all_layer) - .Inputs({ - "x", - paddle::Vec("ln1Gamma"), - paddle::Vec("ln1Beta"), - paddle::Vec("qkvWeight"), - paddle::Vec("qkvBiasWeight"), - paddle::Vec("attnOutWeight"), - paddle::Vec("attnOutBias"), - paddle::Vec("ln2Gamma"), - paddle::Vec("ln2Beta"), - paddle::Vec("gateWeight"), - paddle::Vec("gateBias"), - paddle::Vec("upWeight"), - paddle::Vec("upBias"), - paddle::Vec("downWeight"), - paddle::Vec("downBias"), - "pastSeqLen", - "currentSeqLen", - "step", - }) - .Outputs({"out"}) - .Attrs({"hiddensize :int", - "totalLayer :int", - "computeType : std::string", - "activation :std::string", - "normType :std::string", - "attHeadDim: int", - "attHeadNum: int", - "kvHeadNum: int", - "maxPositions: int", - "maxPosEmbed: int", - "intermediateSize: int"}) - .SetKernelFn(PD_KERNEL(InvokeAllLLaMALayer)) - .SetInferShapeFn(PD_INFER_SHAPE(AllLLaMALayerInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(AllLLaMALayerInferDtype)); diff --git a/custom_ops/cpu_ops/xft_greedy_search.cc b/custom_ops/cpu_ops/xft_greedy_search.cc deleted file mode 100644 index 4ee78a768..000000000 --- a/custom_ops/cpu_ops/xft_greedy_search.cc +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include -#include -#include -#include "paddle/extension.h" - -void greedy_search(const float *probs, - int64_t *next_token_ids, - int bsz, - int vocab_size) { - int numThreads = 0; -#pragma omp parallel - { - int tid = omp_get_thread_num(); - if (tid == 0) { - numThreads = omp_get_num_threads(); - } - } - float maxVals[bsz]; - - // Small batch size (each sample can have at least 2 threads) - if (numThreads / bsz >= 2) { - int thrPerSample = numThreads / bsz; - int sizePerThr = (vocab_size + thrPerSample - 1) / thrPerSample; - int maxIndices[bsz * thrPerSample]; - float maxValues[bsz * thrPerSample]; - - // TODO: if size is small, possible to cause out of boundary -#pragma omp parallel for collapse(2) - for (int b = 0; b < bsz; ++b) { - for (int t = 0; t < thrPerSample; ++t) { - int start = t * sizePerThr; - int end = (start + sizePerThr) > vocab_size - ? vocab_size - : (start + sizePerThr); - const float *p = probs + b * vocab_size; - int maxIdx = start; - float maxVal = p[start]; - for (int off = start + 1; off < end; ++off) { - if (p[off] > maxVal) { - maxVal = p[off]; - maxIdx = off; - } - } - - // False sharing happens, but since only one time, not avoided - maxIndices[b * thrPerSample + t] = maxIdx; - maxValues[b * thrPerSample + t] = maxVal; - } - } - - // Local reduction - for (int i = 0; i < bsz; ++i) { - int *pIndices = maxIndices + i * thrPerSample; - float *pValues = maxValues + i * thrPerSample; - int maxIdx = pIndices[0]; - float maxVal = pValues[0]; - for (int j = 1; j < thrPerSample; ++j) { - if (pValues[j] > maxVal) { - maxVal = pValues[j]; - maxIdx = pIndices[j]; - } - } - next_token_ids[i] = maxIdx; - maxVals[i] = maxVal; - } - } - - // Each thread handle one sample (one row) - else { -#pragma omp parallel for - for (int i = 0; i < bsz; ++i) { - int maxId = 0; - const float *p = probs + i * vocab_size; - float maxVal = p[0]; - for (int j = 1; j < vocab_size; ++j) { - if (p[j] > maxVal) { - maxVal = p[j]; - maxId = j; - } - } - next_token_ids[i] = maxId; - maxVals[i] = maxVal; - } - } - return; -} -std::vector XftGreedySearch(const paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto next_tokens = - paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place()); - - greedy_search(probs.data(), - const_cast(next_tokens.data()), - bsz, - vocab_size); - return {next_tokens}; -} -std::vector> XftGreedySearchInferShape( - const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - return {{bsz, 1}}; -} -std::vector XftGreedySearchInferDtype( - const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64}; -} -PD_BUILD_STATIC_OP(xft_greedy_search) - .Inputs({"probs"}) - .Outputs({"next_tokens_ids"}) - .SetInferShapeFn(PD_INFER_SHAPE(XftGreedySearchInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(XftGreedySearchInferDtype)) - .SetKernelFn(PD_KERNEL(XftGreedySearch)); diff --git a/custom_ops/gpu_ops/air_topp_sampling.cu b/custom_ops/gpu_ops/air_topp_sampling.cu deleted file mode 100644 index 92318b38d..000000000 --- a/custom_ops/gpu_ops/air_topp_sampling.cu +++ /dev/null @@ -1,1612 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include "helper.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/backends/context_pool.h" -#include "paddle/phi/core/stream.h" - -#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") - -#define FINAL_MASK 0xFFFFFFFF - -#define FIXED_BLOCK_DIM_BASE(dim, ...) \ - case (dim): { \ - constexpr auto kBlockDim = (dim); \ - __VA_ARGS__; \ - } break - - -#define FIXED_BLOCK_DIM(...) \ - FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) - -template -struct alignas(128) Counter -{ - T const* in; - IdxT const* inIdx; - - IdxT oriLen; - - AccT sum; - IdxT len; - float p; - IdxT previousLen; - typename cub::Traits::UnsignedBits kthValueBits; - - alignas(128) IdxT filterCnt; - alignas(128) uint32_t finishedBlockCnt; -}; - -template -constexpr __host__ __device__ IntType ceilDiv(IntType a, IntType b) -{ - return (a + b - 1) / b; -} - -template -constexpr __host__ __device__ IntType alignTo(IntType a, IntType b) -{ - return ceilDiv(a, b) * b; -} - -/** - * This function calculate the bufLen, which is the size of buffer. - * When the number of candidates for next pass exceeds the bufLen, we choose not to store the candidates. Otherwise, we - * will load candidates from the original input data. 
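 * A concrete instance of the formula implemented by calcBufLen below, assuming
 * T = float and IdxT = int (and a purely illustrative len of 152064):
 *     ratio  = 2 + sizeof(IdxT) * 2 / sizeof(T) = 2 + 8 / 4 = 4
 *     bufLen = alignTo(152064 / (4 * 8), 256) = alignTo(4752, 256) = 4864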
- */ -template -__host__ __device__ IdxT calcBufLen(IdxT len) -{ - IdxT constexpr ratio = 2 + sizeof(IdxT) * 2 / sizeof(T); - IdxT bufLen = len / (ratio * 8); - bufLen = alignTo(bufLen, 256); - return bufLen; -} - -template -__host__ __device__ constexpr int calcNumPasses() -{ - return ceilDiv(sizeof(T) * 8, BitsPerPass); -} - -template -__device__ typename cub::Traits::UnsignedBits twiddleIn(T key, bool selectMin) -{ - auto bits = reinterpret_cast::UnsignedBits&>(key); - bits = cub::Traits::TwiddleIn(bits); - if (!selectMin) - { - bits = ~bits; - } - return bits; -} - -template -__device__ T twiddleOut(typename cub::Traits::UnsignedBits bits, bool selectMin) -{ - if (!selectMin) - { - bits = ~bits; - } - bits = cub::Traits::TwiddleOut(bits); - return reinterpret_cast(bits); -} - -template -__host__ __device__ constexpr int calcNumBuckets() -{ - return 1 << BitsPerPass; -} - -template -__device__ constexpr int calcStartBit() -{ - constexpr int tmpBit = sizeof(T) * 8 - (Pass + 1) * BitsPerPass; - - constexpr int startBit = tmpBit < 0 ? 0 : tmpBit; - return startBit; -} - -template -__device__ constexpr uint32_t calcMask() -{ - static_assert(BitsPerPass <= 31); - constexpr int numBits = calcStartBit() - calcStartBit(); - return (1 << numBits) - 1; -} - -/** - * Find the bucket based on the radix - */ -template -__device__ int calcBucket(T x, int startBit, uint32_t mask, bool selectMin) -{ - return (twiddleIn(x, selectMin) >> startBit) & mask; -} - -/** - * Replace histogram with its own prefix sum (step 2 in `airTopPSampling` description) - */ -template -__device__ void scan(IdxT volatile* histogram, IdxT* histogramOut) -{ - int constexpr numBuckets = calcNumBuckets(); - if constexpr (numBuckets >= BlockSize) - { - static_assert(numBuckets % BlockSize == 0); - int constexpr itemsPerThread = numBuckets / BlockSize; - typedef cub::BlockLoad BlockLoad; - typedef cub::BlockStore BlockStore; - typedef cub::BlockScan BlockScan; - - __shared__ union - { - typename BlockLoad::TempStorage load; - typename BlockScan::TempStorage scan; - typename BlockStore::TempStorage store; - } tempStorage; - - IdxT threadData[itemsPerThread]; - - BlockLoad(tempStorage.load).Load(histogram, threadData); - __syncthreads(); - - BlockScan(tempStorage.scan).InclusiveSum(threadData, threadData); - __syncthreads(); - - BlockStore(tempStorage.store).Store(histogramOut, threadData); - } - else - { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage tempStorage; - - IdxT threadData = 0; - if (threadIdx.x < numBuckets) - { - threadData = histogram[threadIdx.x]; - } - - BlockScan(tempStorage).InclusiveSum(threadData, threadData); - __syncthreads(); - - if (threadIdx.x < numBuckets) - { - histogramOut[threadIdx.x] = threadData; - } - } -} - -template -__device__ __forceinline__ void filterAndHistogram(const T *in_buffer, - const int *in_idx_buffer, - T *out_buffer, - int *out_idx_buffer, - T *out_scores, - int64_t *out_ids, - int previous_len, - Counter *counter, - T *histogram, - int *count_histogram, - T *histogram_shm, - int *count_histogram_shm, - const bool early_stop) { - // scan and filter - constexpr int start_bit = calcStartBit(); - const uint32_t mask = calcMask(); - constexpr int VecSize = 16 / sizeof(T); - const int bid = blockIdx.y, tid = threadIdx.x; - using VecT = uint4; - union { - VecT v; - T array[VecSize]; - } vec; - for (int i = (blockIdx.x * blockDim.x + threadIdx.x) ; i < ceilDiv(previous_len, VecSize); i += blockDim.x * gridDim.x) { - vec.v = reinterpret_cast(in_buffer)[i]; - 
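// VecT is uint4 (a 16-byte load), so each iteration consumes VecSize = 16 / sizeof(T)
// elements of in_buffer. On Pass == 0 the loop only accumulates the shared-memory
// histograms; on later passes it keeps only elements whose already-decided high bits
// match counter->kthValueBits and, when out_buffer is set, compacts the survivors into
// out_buffer / out_idx_buffer before histogramming the next BitsPerPass bits.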
if constexpr (Pass == 0) { -#pragma unroll - for (int j = 0; j < VecSize; j++) { - if (i * VecSize + j < previous_len) { - int bucket = calcBucket(vec.array[j], start_bit, mask, false); - atomicAdd(histogram_shm + bucket, vec.array[j]); - atomicAdd(count_histogram_shm + bucket, 1); - } - } - } else { - int *filter_cnt = &counter->filterCnt; - const auto kthValueBits = counter->kthValueBits; - constexpr int previousStartBit = calcStartBit(); -#pragma unroll - for (int j = 0; j < VecSize; j++) { - const int idx = i * VecSize + j; - if (idx < previous_len) { - const auto previousBits = (twiddleIn(vec.array[j], false) >> previousStartBit) << previousStartBit; - if (previousBits == kthValueBits) { - if (early_stop) { - const int pos = in_idx_buffer ? in_idx_buffer[idx] : idx; - out_scores[bid] = vec.array[j]; - out_ids[bid] = pos; - } - if (out_buffer) { - int pos = atomicAdd(filter_cnt, 1); - out_buffer[pos] = vec.array[j]; - out_idx_buffer[pos] = in_idx_buffer ? in_idx_buffer[idx] : idx; - } - int bucket = calcBucket(vec.array[j], start_bit, mask, false); - atomicAdd(histogram_shm + bucket, vec.array[j]); - atomicAdd(count_histogram_shm + bucket, 1); - } - } - } - } - } - __syncthreads(); - if (early_stop) { - return; - } - for (int i = tid; i < NumBuckets; i += blockDim.x) { - if (count_histogram_shm[i] > 0) { - atomicAdd(histogram + i, histogram_shm[i]); - atomicAdd(count_histogram + i, count_histogram_shm[i]); - } - } -} - -template -__global__ void air_topp_sampling(Counter *counters, - T *histograms, - int *count_histograms, - T *out, - int64_t *ids, - T *buf1, - int *idx_buf1, - T *buf2, - int *idx_buf2, - int* count_iter, - int* count_iter_begin, - const int buf_len) { - - /*** - * calc - filter - scan -find - * TODO: calc - scan - find - filter - ***/ - const int bid = blockIdx.y; - if (count_iter_begin[bid] == count_iter[bid + 1]) { - // topk - return; - } - - const int tid = threadIdx.x; - auto counter = counters + bid; - - T current_sum; - int previous_len, current_len; - if constexpr (Pass == 0) { - current_sum = 0; - previous_len = counter->len; - current_len = counter->len; - } else { - current_sum = counter->sum; - previous_len = counter->previousLen; - current_len = counter->len; - } - if (current_len == 0) { - return; - } - const bool early_stop = (current_len == 1); - const T *in_buf = nullptr; - const int *in_idx_buf = nullptr; - T *out_buf = nullptr; - int *out_idx_buf = nullptr; - const int buf_offset = bid * buf_len; - if constexpr (Pass == 0) { - in_buf = counter->in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if constexpr (Pass == 1) { - in_buf = counter->in; - in_idx_buf = nullptr; - out_buf = buf1 + buf_offset; - out_idx_buf = idx_buf1 + buf_offset; - } else { - in_buf = buf1 + buf_offset; - in_idx_buf = idx_buf1 + buf_offset; - out_buf = buf2 + buf_offset; - out_idx_buf = idx_buf2 + buf_offset; - } - - if (Pass == 0 || Pass == 1 || previous_len > buf_len) { - previous_len = counter->oriLen; - in_buf = counter->in; - in_idx_buf = nullptr; - } - if (Pass == 0 || current_len > buf_len) { - out_buf = nullptr; - out_idx_buf = nullptr; - } - - auto histogram = histograms + bid * NumBuckets; - auto count_histogram = count_histograms + bid * NumBuckets; - __shared__ T histogram_shm[NumBuckets]; - __shared__ int count_histogram_shm[NumBuckets]; - for (int i = tid; i < NumBuckets; i += blockDim.x) { - histogram_shm[i] = 0; - count_histogram_shm[i] = 0; - } - __syncthreads(); - - filterAndHistogram( - in_buf, - in_idx_buf, - out_buf, - 
out_idx_buf, - out, - ids, - previous_len, - counter, - histogram, - count_histogram, - histogram_shm, - count_histogram_shm, - early_stop - ); - __syncthreads(); - __threadfence(); - - // find last block - bool isLastBlock = false; - if (threadIdx.x == 0) { - uint32_t finished = atomicInc(&counter->finishedBlockCnt, gridDim.x - 1); - isLastBlock = (finished == (gridDim.x - 1)); - } - - if (__syncthreads_or(isLastBlock)) { - if (early_stop) { - if (threadIdx.x == 0) { - counter->previousLen = 0; - counter->len = 0; - } - return; - } - - // scan/find - constexpr int WARP_SIZE = 32; - constexpr int WARP_COUNT = NumBuckets / WARP_SIZE; - namespace cg = cooperative_groups; - cg::thread_block block = cg::this_thread_block(); - cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); - __shared__ T warpSum[WARP_COUNT]; - __shared__ cuda::atomic blockSum; - for (int i = tid; i < WARP_COUNT; i += BlockSize) { - warpSum[i] = 0; - } - if (tid == 0) { - blockSum = 0; - } - __syncthreads(); - // Acquire the summation of each 32 buckets - for (int i = threadIdx.x; i < NumBuckets; i += BlockSize) { - reduce_store_async(warp, warpSum + i / WARP_SIZE, histogram[i], cg::plus{}); - } - __syncthreads(); - // Acquire the summation of all the 2048 buckets - if (threadIdx.x < WARP_SIZE) { - reduce_store_async(warp, blockSum, warpSum[threadIdx.x], cg::plus{}); - reduce_update_async(warp, blockSum, warpSum[threadIdx.x + WARP_SIZE], cg::plus{}); - } - __syncthreads(); - - if constexpr (Pass == 0) { - current_sum = blockSum * counter->p; - } - - if (tid == 0) { - T prev = 0; - - // Add 32 elements each step - int iStep = 0; - int targetStep = 0; - for (; iStep < WARP_COUNT; iStep++) { - if (warpSum[iStep]) { - targetStep = iStep; - if ((prev + warpSum[iStep]) >= current_sum) { - break; - } - prev += warpSum[iStep]; - } - } - - int targetIdx = 0; - for (int i = targetStep * WARP_SIZE; i < NumBuckets; i++) { - if (count_histogram[i]) { - targetIdx = i; - if ((prev + histogram[i]) >= current_sum) { - break; - } - prev += histogram[i]; - } - } - counter->sum = current_sum - prev; // how many values still are there to find - counter->len = count_histogram[targetIdx]; // cur - prev; // number of values in next pass - typename cub::Traits::UnsignedBits bucket = targetIdx; - int startBit = calcStartBit(); - counter->kthValueBits |= bucket << startBit; - } - __syncthreads(); - constexpr int numPasses = calcNumPasses(); - if constexpr (Pass != numPasses - 1) { - for (int i = tid; i < NumBuckets; i += BlockSize) { - histogram[i] = 0; - count_histogram[i] = 0; - } - } - if (tid == 0) { - // recover - counter->previousLen = current_len; - counter->filterCnt = 0; - } - if constexpr (Pass == numPasses - 1) { - const auto kthValueBits = counter->kthValueBits; - const auto equal_value = twiddleOut(kthValueBits, false); - - const T *last_data = out_buf ? out_buf : in_buf; - const int *last_idx_data = out_idx_buf ? out_idx_buf : in_idx_buf; - const int last_len = out_buf ? current_len : counter->oriLen; - for (int i = tid; i < last_len; i += BlockSize) { - if (last_data[i] == equal_value) { - out[bid] = equal_value; - ids[bid] = last_idx_data ? 
last_idx_data[i] : i; - } - } - } - } -} - -template -__global__ void air_topp_init(Counter *counters, - T *histograms, - int *count_histograms, - const T *in, - const T *ps, - curandState_t* curandstate, - const int bsz, - const int vocab_size, - const int buf_len, - const int num_buckets) { - const int bid = blockIdx.x; - const int tid = threadIdx.x; - Counter *counter_now = counters + bid; - T *histogram_now = histograms + bid * num_buckets; - int *count_histogram_now = count_histograms + bid * num_buckets; - const int offset = bid * vocab_size; - if (tid == 0) { - counter_now->in = in + offset; - - counter_now->len = vocab_size; - counter_now->oriLen = vocab_size; - counter_now->previousLen = vocab_size; - - const T p = ps[bid]; - const T rand_p = curand_uniform(curandstate + bid) * p; - counter_now->p = rand_p; - - counter_now->sum = 0; - - counter_now->kthValueBits = 0; - counter_now->filterCnt = 0; - counter_now->finishedBlockCnt = 0; - } - for (int i = tid; i < num_buckets; i += blockDim.x) { - histogram_now[i] = 0; - count_histogram_now[i] = 0; - } -} - -struct SegmentOffsetIter { - explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {} - - __host__ __device__ __forceinline__ int operator()(int idx) const { - return idx * num_cols_; - } - - int num_cols_; -}; - -template -struct Pair { - __device__ __forceinline__ Pair() {} - __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {} - - __device__ __forceinline__ void set(T value, int id) { - this->v = value; - this->id = id; - } - - __device__ __forceinline__ void operator=(const Pair& in) { - v = in.v; - id = in.id; - } - - __device__ __forceinline__ bool operator<(const T value) const { - return (static_cast(v) < static_cast(value)); - } - - __device__ __forceinline__ bool operator>(const T value) const { - return (static_cast(v) > static_cast(value)); - } - __device__ __forceinline__ bool operator<(const Pair& in) const { - return (static_cast(v) < static_cast(in.v)) || - ((static_cast(v) == static_cast(in.v)) && - (id > in.id)); - } - - __device__ __forceinline__ bool operator>(const Pair& in) const { - return (static_cast(v) > static_cast(in.v)) || - ((static_cast(v) == static_cast(in.v)) && - (id < in.id)); - } - - T v; - int id; -}; - -inline int div_up(int a, int n) { return (a + n - 1) / n; } - -template -__device__ __forceinline__ void AddTo(Pair topk[], - const Pair& p, - int beam_size) { - for (int k = beam_size - 2; k >= 0; k--) { - if (topk[k] < p) { - topk[k + 1] = topk[k]; - } else { - topk[k + 1] = p; - return; - } - } - topk[0] = p; -} - -template -__device__ __forceinline__ void GetTopK(Pair topk[], - const T* src, - int idx, - int dim, - int beam_size) { - while (idx < dim) { - if (topk[beam_size - 1] < src[idx]) { - Pair tmp(src[idx], idx); - AddTo(topk, tmp, beam_size); - } - idx += BlockSize; - } -} - -template -__device__ __forceinline__ void GetTopK(Pair topk[], - const T* src, - int idx, - int dim, - const Pair& max, - int beam_size) { - while (idx < dim) { - if (topk[beam_size - 1] < src[idx]) { - Pair tmp(src[idx], idx); - if (tmp < max) { - AddTo(topk, tmp, beam_size); - } - } - idx += BlockSize; - } -} - -template -__device__ __forceinline__ void ThreadGetTopK(Pair topk[], - int* beam, - int beam_size, - const T* src, - bool* firstStep, - bool* is_empty, - Pair* max, - int dim, - const int tid) { - if (*beam > 0) { - int length = (*beam) < beam_size ? 
*beam : beam_size; - if (*firstStep) { - *firstStep = false; - GetTopK(topk, src, tid, dim, length); - } else { - for (int k = 0; k < MaxLength; k++) { - if (k < MaxLength - (*beam)) { - topk[k] = topk[k + *beam]; - } else { - topk[k].set(std::numeric_limits::min(), -1); - } - } - if (!(*is_empty)) { - GetTopK( - topk + MaxLength - *beam, src, tid, dim, *max, length); - } - } - - *max = topk[MaxLength - 1]; - if ((*max).id == -1) *is_empty = true; - *beam = 0; - } -} - -template -__forceinline__ __device__ T -CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) { - return __shfl_down_sync(mask, val, static_cast(delta), width); -} - -template -__forceinline__ __device__ Pair WarpReduce(Pair input) { -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - T tmp_val = - CudaShuffleDownSync(FINAL_MASK, input.v, offset, 32); - int tmp_id = - CudaShuffleDownSync(FINAL_MASK, input.id, offset, 32); - if (static_cast(input.v) < static_cast(tmp_val)) { - input.v = tmp_val; - input.id = tmp_id; - } - } - return input; -} - -template -__device__ __forceinline__ void BlockReduce(Pair shared_max[], - Pair topk[], - Pair beam_max[], - int* beam, - int* k, - int* count, - const int tid, - const int wid, - const int lane) { - while (true) { - __syncthreads(); - Pair input_now = topk[0]; - input_now = WarpReduce(input_now); - - if (lane == 0) { - shared_max[wid] = input_now; - } - __syncthreads(); - input_now = (tid < BlockSize / 32) - ? shared_max[lane] - : Pair(std::numeric_limits::min(), -1); - if (wid == 0) { - input_now = WarpReduce(input_now); - if (lane == 0) shared_max[0] = input_now; - } - __syncthreads(); - if (tid == 0) { - beam_max[*count] = shared_max[0]; - (*count)++; - } - int tid_max = shared_max[0].id % BlockSize; - if (tid == tid_max) { - (*beam)++; - } - if (--(*k) == 0) break; - __syncthreads(); - - if (tid == tid_max) { - if (*beam < MaxLength) { - topk[0] = topk[*beam]; - } - } - - if (MaxLength < 5) { - if (*beam >= MaxLength) break; - } else { - unsigned mask = 0u; - mask = __ballot_sync(FINAL_MASK, true); - if (tid_max / 32 == wid) { - if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength) - break; - } - } - } -} - -template -__device__ inline T exponential_transform(T val, T lambda) { -#if defined(__NVCC__) || defined(__HIPCC__) - T log = -std::numeric_limits::epsilon() / 2; - if (val < static_cast(1.) - std::numeric_limits::epsilon() / 2) { - if (std::is_same::value) { - log = logf(val); - } else { - log = __logf(val); - } - } - return static_cast(-1.0) / lambda * log; -#else - return static_cast(-1.0) / lambda * std::log(static_cast(1.0) - val); -#endif -} - -template -__global__ void KeMatrixTopPBeamTopK(const T* src, - const T* threshold, - curandState_t* states, - T* top_ps, - int64_t* out_id, // topk id - T* out_val, // topk val - int64_t* topk_ids, - T* topk_scores, - int vocab_size, - int* count_iter, - int* count_iter_begin, - const int k, - const bool need_batch_random) { - const int tid = threadIdx.x; - const int wid = tid / 32; - const int lane = tid % 32; - const int bid = blockIdx.x; - const float threshold_now = threshold ? 
static_cast(threshold[bid]) : 0.f; - - int top_num = TopPBeamTopK; - float top_p_num = static_cast(top_ps[bid]); - const int offset = bid * vocab_size; - int64_t *topk_ids_now = topk_ids + bid * k; - T* topk_scores_now = topk_scores + bid * k; - - __shared__ Pair shared_max[BlockSize / 32]; - __shared__ Pair beam_max[TopPBeamTopK]; - - Pair topk[MaxLength]; - int beam = MaxLength; - Pair max; - bool is_empty = false; - bool firststep = true; - __shared__ int count; - - if (tid == 0) { - count = 0; - } - - for (int j = 0; j < MaxLength; j++) { - topk[j].set(std::numeric_limits::min(), -1); - } - - while (top_num) { - ThreadGetTopK(topk, - &beam, - TopPBeamTopK, - src + offset, - &firststep, - &is_empty, - &max, - vocab_size, - tid); - BlockReduce( - shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane); - } - if (tid == 0) { - // printf("offset: %d\n", (int)seed_offset); - count_iter_begin[bid] = count_iter[bid]; - float top_p = top_ps[bid]; - float sum_prob = 0.0f; - bool flag = false; - float max_val = 0.f; - int max_id = -1; - for (int i = 0; i < TopPBeamTopK; i++) { - if (i < k) { - topk_ids_now[i] = static_cast(beam_max[i].id); - topk_scores_now[i] = beam_max[i].v; - } - if (!flag) { - float val = static_cast(beam_max[i].v); - sum_prob += val; - float random_ratio = exponential_transform(curand_uniform(states + bid), 1.0f); - // for (int t = 0; t < 5; t++) { - // float tmp_random_ratio = curand_uniform(&state); - // printf("step: %d, tmp_random_ratio: %f\n", t, tmp_random_ratio); - // } - float random_val = (val >= threshold_now ? val : 0.f) / random_ratio; - // printf("random_ratio: %f, val: %f, random_val: %f\n", random_ratio, val, random_val); - if (max_val < random_val) { - max_val = random_val; - max_id = i; - } - if (sum_prob >= top_p) { - flag = true; - count_iter_begin[bid] += 1; - if (max_id == -1) { - // don't sample low score token - out_id[bid] = static_cast(beam_max[0].id); - out_val[bid] = beam_max[0].v; - } else { - out_id[bid] = static_cast(beam_max[max_id].id); - out_val[bid] = beam_max[max_id].v; - } - } - } - if (flag && i >= k - 1) { - break; - } - } - } -} - -template -__global__ void KeMatrixTopPBeamTopKFt(const T* src, - const T* threshold, - curandState_t* states, - T* top_ps, - int64_t* out_id, // topk id - T* out_val, // topk val - int64_t* topk_ids, - T* topk_scores, - int vocab_size, - int* count_iter, - int* count_iter_begin, - const int k, - const bool need_batch_random) { - const int tid = threadIdx.x; - const int wid = tid / 32; - const int lane = tid % 32; - const int bid = blockIdx.x; - const float threshold_now = threshold ? 
static_cast(threshold[bid]) : 0.f; - - int top_num = TopPBeamTopK; - float top_p_num = static_cast(top_ps[bid]); - int64_t* topk_ids_now = topk_ids + bid * k; - T* topk_scores_now = topk_scores + bid * k; - - __shared__ Pair shared_max[BlockSize / 32]; - __shared__ Pair beam_max[TopPBeamTopK]; - - Pair topk[MaxLength]; - int beam = MaxLength; - Pair max; - bool is_empty = false; - bool firststep = true; - __shared__ int count; - - if (tid == 0) { - count = 0; - } - - for (int j = 0; j < MaxLength; j++) { - topk[j].set(std::numeric_limits::min(), -1); - } - - while (top_num) { - ThreadGetTopK(topk, - &beam, - TopPBeamTopK, - src + bid * vocab_size, - &firststep, - &is_empty, - &max, - vocab_size, - tid); - BlockReduce( - shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane); - } - if (tid == 0) { - count_iter_begin[bid] = count_iter[bid]; - float rand_top_p = curand_uniform(states + bid) * top_p_num; - top_ps[bid] = (T)rand_top_p; - float sum_prob = 0.0f; - bool flag = false; - for (int i = 0; i < TopPBeamTopK; i++) { - if (i < k) { - topk_ids_now[i] = static_cast(beam_max[i].id); - topk_scores_now[i] = beam_max[i].v; - } - if (!flag) { - float val = static_cast(beam_max[i].v); - sum_prob += val; - if (sum_prob >= rand_top_p) { - flag = true; - count_iter_begin[bid] += 1; - if (val < threshold_now) { - // don't sample low score token - int start_id = i == 0 ? 0 : i - 1; - for (int j = start_id; j >= 0; j--) { - float val_now = static_cast(beam_max[j].v); - if (val_now >= threshold_now || j == 0) { - out_id[bid] = static_cast(beam_max[j].id); - out_val[bid] = beam_max[j].v; - break; - } - } - } else { - out_id[bid] = static_cast(beam_max[i].id); - out_val[bid] = beam_max[i].v; - } - } - } - if (flag && i >= k - 1) { - break; - } - } - } -} - -__global__ void AirToppSetCountIter(int* count_iter, int num) { - int tid = threadIdx.x; - int bid = blockIdx.x; - int idx = bid * blockDim.x + tid; - for (int i = idx; i < num; i += gridDim.x * blockDim.x) { - count_iter[i] = i; - } -} - -template -__global__ void FillIndex(T* indices, T num_rows, T num_cols) { - int col_id = threadIdx.x; - int row_id = blockIdx.x; - - for (T j = row_id; j < num_rows; j += gridDim.x) { - for (T i = col_id; i < num_cols; i += blockDim.x) { - indices[j * num_cols + i] = i; - } - } -} - -template -void DispatchKeMatrixTopPBeamTopK(const T* src, - const T* threshold, - curandState_t* states, - T* top_ps, - int64_t* out_id, // topk id - T* out_val, // topk val - int64_t* topk_ids, - T* topk_scores, - int vocab_size, - int* count_iter, - int* count_iter_begin, - const int k, - const int bs, - const bool need_batch_random, - const std::string& mode, - cudaStream_t stream) { - int BlockSize = GetBlockSize(vocab_size); - if (mode == "truncated") { - switch (BlockSize) { - FIXED_BLOCK_DIM( - KeMatrixTopPBeamTopKFt - <<>>( - src, - threshold, - states, - top_ps, - out_id, - out_val, - topk_ids, - topk_scores, - vocab_size, - count_iter, - count_iter_begin, - k, - need_batch_random)); - default: - PD_THROW("the input data shape has error in the topp_beam_topk kernel."); - } - } else { - switch (BlockSize) { - FIXED_BLOCK_DIM( - KeMatrixTopPBeamTopK - <<>>( - src, - threshold, - states, - top_ps, - out_id, - out_val, - topk_ids, - topk_scores, - vocab_size, - count_iter, - count_iter_begin, - k, - need_batch_random)); - default: - PD_THROW("the input data shape has error in the topp_beam_topk kernel."); - } - } -} - -struct BlockPrefixCallbackOp { - // Running prefix - float running_total; - // Constructor - 
__device__ BlockPrefixCallbackOp(float running_total): running_total(running_total) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide scan. - __device__ float operator()(float block_aggregate) - { - float old_prefix = running_total; - running_total += block_aggregate; - return old_prefix; - } -}; - -template -__global__ void topp_sampling(T* sorted_probs, - int64_t* sorted_id, - T* out_val, - int64_t* out_id, - const T* top_ps, - const T* threshold, - curandState_t * states, - const int p_num, - const int vocab_size, - const bool need_batch_random, - int* count_iter, - int* count_iter_begin) { - __shared__ int stop_shared; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - constexpr int NUM_WARPS = BLOCK_SIZE / 32; - const int lane_id = tid % 32; - const int warp_id = tid / 32; - const float p_t = static_cast(top_ps[bid]); - const float threshold_now = threshold ? static_cast(threshold[bid]) : 0.f; - if (tid == 0) { - stop_shared = 0; - } - if (count_iter_begin[bid] == count_iter[bid + 1]) { - // topk - return; - } - - typedef cub::BlockScan BlockScan; - typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; - __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ typename BlockReduce::TempStorage temp_storage_reduce; - - // Initialize running total - BlockPrefixCallbackOp prefix_op(0); - - int offset = bid * vocab_size; - int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int i_activate = 0; - float thread_offset = 0; - Pair max_thread_pair(static_cast(0.), -1); - for (int i = tid; i < end; i += BLOCK_SIZE) { - float thread_count = - (i < vocab_size) ? static_cast(sorted_probs[offset + i]) : 0.f; - BlockScan(temp_storage) - .InclusiveSum(thread_count, thread_offset, prefix_op); - - if (thread_offset < p_t || (thread_offset >= p_t && thread_offset - thread_count < p_t)) { - float random_ratio = exponential_transform(curand_uniform(states + bid), 1.0f); - float tmp_val = (thread_count >= threshold_now ? thread_count : 0.f) / random_ratio; - if (static_cast(max_thread_pair.v) < tmp_val) { - max_thread_pair.set(static_cast(tmp_val), i); - } - uint32_t activate_mask = __ballot_sync(FINAL_MASK, p_t <= thread_offset); - - i_activate = i; - if (activate_mask != 0) { - if (lane_id == 0) { - atomicAdd(&stop_shared, 1); - } - } - __syncthreads(); - if (stop_shared > 0) { - break; - } - } - __syncthreads(); - if (stop_shared == 0) { - if (tid == 0) { - out_id[bid] = sorted_id[offset]; - out_val[bid] = sorted_probs[offset]; - } - return; - } - Pair max_pair = BlockReduce(temp_storage_reduce).Reduce(max_thread_pair, MaxOp>()); - if (tid == 0) { - if (max_pair.id == -1) { - max_pair.id = 0; - } - out_id[bid] = sorted_id[offset + max_pair.id]; - out_val[bid] = sorted_probs[offset + max_pair.id]; - } - } -} - -template -__global__ void topp_sampling_ft(T* sorted_probs, - int64_t* sorted_id, - T* out_val, - int64_t* out_id, - const T* top_ps, - const T* threshold, - curandState_t* states, - const int p_num, - const int vocab_size, - const bool need_batch_random, - int* count_iter, - int* count_iter_begin) { - __shared__ int stop_shared; - __shared__ float rand_p; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - constexpr int NUM_WARPS = BLOCK_SIZE / 32; - const int lane_id = tid % 32; - const int warp_id = tid / 32; - const float p_t = static_cast(top_ps[bid]); - const float threshold_now = threshold ? 
static_cast(threshold[bid]) : 0.f; - if (tid == 0) { - stop_shared = 0; - rand_p = p_t; - } - if (count_iter_begin[bid] == count_iter[bid + 1]) { - // topk - return; - } - - typedef cub::BlockScan BlockScan; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ typename BlockReduce::TempStorage temp_storage_reduce; - __shared__ uint32_t selected_shared[NUM_WARPS]; - int threshold_id = 0; - - // Initialize running total - BlockPrefixCallbackOp prefix_op(0); - - if (lane_id == 0) { - selected_shared[warp_id] = 0; - } - __syncthreads(); - - int offset = bid * vocab_size; - int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int i_activate = 0; - float thread_offset = 0; - for (int i = tid; i < end; i += BLOCK_SIZE) { - float thread_count = - (i < vocab_size) ? static_cast(sorted_probs[offset + i]) : 0.f; - if (i < vocab_size && thread_count >= threshold_now) { - threshold_id = i; - } - BlockScan(temp_storage) - .InclusiveSum(thread_count, thread_offset, prefix_op); - - uint32_t activate_mask = __ballot_sync(FINAL_MASK, rand_p <= thread_offset); - - i_activate = i; - if (activate_mask != 0) { - if (lane_id == 0) { - atomicAdd(&stop_shared, 1); - selected_shared[warp_id] = activate_mask; - } - } - __syncthreads(); - if (stop_shared > 0) { - break; - } - } - __syncthreads(); - if (stop_shared == 0) { - if (tid == 0) { - out_id[bid] = sorted_id[offset]; - out_val[bid] = sorted_probs[offset]; - } - return; - } - bool skip = (selected_shared[warp_id] > 0) ? false : true; - for (int i = 0; i < warp_id; i++) { - if (selected_shared[i] != 0) { - // If the previous has stopped, skip the current warp - skip = true; - } - } - if (!skip) { - int active_lane_id = - 32 - __popc(selected_shared[warp_id]); // first not 0 - if (lane_id == active_lane_id) { - float val = static_cast(sorted_probs[offset + i_activate]); - if (val < threshold_now) { - // don't sample low score token - int max_id = BlockReduce(temp_storage_reduce).Reduce(threshold_id, MaxOp()); - curandStatePhilox4_32_10_t rng; - curand_init(bid * blockDim.x + tid, tid, 0, &rng); - int random_id = curand(&rng) % (max_id + 1); - out_id[bid] = sorted_id[offset + random_id]; - out_val[bid] = sorted_probs[offset + random_id]; - } else { - out_id[bid] = sorted_id[offset + i_activate]; - out_val[bid] = sorted_probs[offset + i_activate]; - } - } - } -} - -template -void DispatchTopPSampling(T* sorted_probs, - int64_t* sorted_id, - T* out_val, - int64_t* out_id, - const T* top_ps, - const T* threshold, - curandState_t* states, - const int p_num, - const int vocab_size, - const int bs, - const bool need_batch_random, - int* count_iter, - int* count_iter_begin, - const std::string& mode, - cudaStream_t stream) { - int BlockSize = GetBlockSize(vocab_size); - if (mode == "truncated") { - switch (BlockSize) { - FIXED_BLOCK_DIM(topp_sampling_ft - <<>>( - sorted_probs, - sorted_id, - out_val, - out_id, - top_ps, - threshold, - states, - p_num, - vocab_size, - need_batch_random, - count_iter, - count_iter_begin)); - default: - PD_THROW("the input data shape has error in the topp_sampling kernel."); - } - } else { - switch (BlockSize) { - FIXED_BLOCK_DIM(topp_sampling - <<>>( - sorted_probs, - sorted_id, - out_val, - out_id, - top_ps, - threshold, - states, - p_num, - vocab_size, - need_batch_random, - count_iter, - count_iter_begin)); - default: - PD_THROW("the input data shape has error in the topp_sampling kernel."); - } - } -} - -__global__ void air_topp_setup_kernel(curandState_t* 
state, - int64_t* seed, - const int bs) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { - curand_init(static_cast(seed[i]), 0, 0, &state[i]); - } -} - -__global__ void air_topp_setup_kernel(curandState_t* state, - const uint64_t seed, - const uint64_t offset, - const int bs, - const bool need_batch_random) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { - if (need_batch_random) { - curand_init(seed, i, offset, &state[i]); - } else { - curand_init(seed, 0, offset, &state[i]); - } - } -} - -template -__global__ void print_kernel(T* input, int size) { - printf("["); - for (int i = 0; i < size; i++) { - if (i != size - 1) { - printf("%f, ", static_cast(input[i])); - } else { - printf("%f]\n", static_cast(input[i])); - } - } -} - -template -std::vector LaunchTopPSampling(const paddle::Tensor& x, - const paddle::Tensor& ps, - const paddle::optional& threshold, - const paddle::optional& topp_seed, - int seed, - int k, - const std::string& mode) { - typedef PDTraits traits_; - typedef typename traits_::DataType DataType_; - typedef typename traits_::data_t data_t; - auto stream = x.stream(); - const auto& in_dims = x.dims(); - int p_num = ps.numel(); - int bs = in_dims[0]; - int vocab_size = in_dims[1]; - - auto out = paddle::empty({bs, 1}, x.dtype(), x.place()); - auto ids = paddle::empty({bs, 1}, paddle::DataType::INT64, x.place()); - auto topk_ids = paddle::empty({bs, k}, paddle::DataType::INT64, x.place()); - auto topk_scores = paddle::empty({bs, k}, x.dtype(), x.place()); - - auto ps_now = ps.copy_to(ps.place(), false); - auto inds_input = paddle::empty({bs, vocab_size}, paddle::DataType::INT64, x.place()); - auto sorted_out = paddle::empty({bs, vocab_size}, x.dtype(), x.place()); - auto sorted_id = paddle::empty({bs, vocab_size}, paddle::DataType::INT64, x.place()); - - int BlockSize = GetBlockSize(vocab_size); - switch (BlockSize) { - FIXED_BLOCK_DIM(FillIndex<<>>( - inds_input.data(), bs, vocab_size)); - default: - PD_THROW("the input data shape has error in the FillIndex kernel."); - } - int64_t* infer_seed = topp_seed ? 
const_cast(topp_seed.get().data()) : nullptr; - - curandState_t* states{nullptr}; - - phi::Allocator::AllocationPtr curand_states_buf{nullptr}; - curand_states_buf = phi::memory_utils::Alloc( - x.place(), - bs * sizeof(curandState_t), - phi::Stream(reinterpret_cast(stream))); - states = reinterpret_cast(curand_states_buf->ptr()); - - - uint64_t seed_now = seed; - uint64_t offset = 0; - bool need_batch_random = false; - - if (infer_seed) { - air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, infer_seed, bs); - } else { - if (seed_now == -1) { - need_batch_random = true; - phi::DeviceContext* dev_ctx = phi::DeviceContextPool::Instance().Get(x.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - uint64_t increment = ps.numel() * 4; - auto seed_offset = gen_cuda->IncrementOffset(increment); - seed_now = seed_offset.first; - offset = seed_offset.second; - air_topp_setup_kernel<<<1, 256, 0, stream>>>( - states, seed_now, offset, bs, need_batch_random); - } else { - air_topp_setup_kernel<<<1, 256, 0, stream>>>( - states, seed_now, offset, bs, need_batch_random); - } - } - - auto count_iter = paddle::empty({bs + 1}, paddle::DataType::INT32, x.place()); - auto count_iter_begin = paddle::empty({bs}, paddle::DataType::INT32, x.place()); - AirToppSetCountIter<<<1, 256, 0, stream>>>(count_iter.data(), bs + 1); - - const data_t* threshold_data = nullptr; - if (threshold) { - threshold_data = threshold.get().data(); - } - - constexpr int TopKMaxLength = 2; - constexpr int TopPBeamTopK = 20; - - DispatchKeMatrixTopPBeamTopK( - reinterpret_cast(x.data()), - reinterpret_cast(threshold_data), - states, - reinterpret_cast(ps_now.data()), - ids.data(), - reinterpret_cast(out.data()), - topk_ids.data(), - reinterpret_cast(topk_scores.data()), - vocab_size, - count_iter.data(), - count_iter_begin.data(), - k, - bs, - need_batch_random, - mode, - stream); - - static_assert(std::is_same::value, "air_topp only supports float now!"); - constexpr int BitsPerPass = 11; - constexpr int SAMPLING_BLOCK_SIZE = 512; - constexpr int INIT_BLOCK_SIZE = 1024; - phi::Allocator::AllocationPtr counter_ptr{nullptr}; - counter_ptr = phi::memory_utils::Alloc( - x.place(), - bs * sizeof(Counter), - phi::Stream(reinterpret_cast(stream))); - Counter *counters = reinterpret_cast*>(counter_ptr->ptr()); - constexpr int numBuckets = calcNumBuckets(); - const int buf_len = calcBufLen(vocab_size); - - auto histograms = paddle::empty({bs, numBuckets}, x.dtype(), x.place()); - auto count_histograms = paddle::empty({bs, numBuckets}, paddle::DataType::INT32, x.place()); - auto buf1 = paddle::empty({bs, bs}, x.dtype(), x.place()); - auto id_buf1 = paddle::empty({bs, buf_len}, paddle::DataType::INT32, x.place()); - auto buf2 = paddle::empty({bs, buf_len}, x.dtype(), x.place()); - auto id_buf2 = paddle::empty({bs, buf_len}, paddle::DataType::INT32, x.place()); - - air_topp_init<<>>( - counters, - reinterpret_cast(histograms.data()), - count_histograms.data(), - reinterpret_cast(x.data()), - reinterpret_cast(ps.data()), - states, - bs, - vocab_size, - buf_len, - numBuckets); - - constexpr int VecSize = 16 / sizeof(data_t); - // TODO: good block_num - const int max_block_num_vocab = ceilDiv(vocab_size, SAMPLING_BLOCK_SIZE * VecSize); - auto kernel = air_topp_sampling; - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, kernel, SAMPLING_BLOCK_SIZE, 0); - assert(act_blocks_per_sm > 1); - 
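// The grid sizing that follows is occupancy-driven rather than purely vocabulary-driven:
// block_per_wave = sm_count * act_blocks_per_sm is one full wave of resident blocks, and
// the per-row block count is capped at block_per_wave * 4 / bs so that all bs rows of the
// (block_num_vocab, bs) grid fit in roughly four waves.
// Illustrative numbers, not taken from this patch: with sm_count = 108,
// act_blocks_per_sm = 2 and bs = 32, the cap is 216 * 4 / 32 = 27, while a hypothetical
// 152064-token vocabulary gives max_block_num_vocab = ceilDiv(152064, 512 * 4) = 75,
// so block_num_vocab = min(75, 27) = 27.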
const int block_per_wave = sm_count * act_blocks_per_sm; - const int block_num_vocab = std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!! - dim3 grid(block_num_vocab, bs); - constexpr int numPasses = calcNumPasses(); - for (int pass = 0; pass < numPasses; ++pass) { - if (pass == 0) { - air_topp_sampling<<>>( - counters, - reinterpret_cast(histograms.data()), - count_histograms.data(), - reinterpret_cast(out.data()), - ids.data(), - reinterpret_cast(buf1.data()), - id_buf1.data(), - reinterpret_cast(buf2.data()), - id_buf2.data(), - count_iter.data(), - count_iter_begin.data(), - buf_len - ); - } else if (pass == 1) { - air_topp_sampling<<>>( - counters, - reinterpret_cast(histograms.data()), - count_histograms.data(), - reinterpret_cast(out.data()), - ids.data(), - reinterpret_cast(buf1.data()), - id_buf1.data(), - reinterpret_cast(buf2.data()), - id_buf2.data(), - count_iter.data(), - count_iter_begin.data(), - buf_len - ); - } else if (pass == 2) { - air_topp_sampling<<>>( - counters, - reinterpret_cast(histograms.data()), - count_histograms.data(), - reinterpret_cast(out.data()), - ids.data(), - reinterpret_cast(buf1.data()), - id_buf1.data(), - reinterpret_cast(buf2.data()), - id_buf2.data(), - count_iter.data(), - count_iter_begin.data(), - buf_len - ); - } else { - PD_THROW("pass must be 0,1 or 2!"); - } - } - return {out, ids}; -} - -std::vector TopPSampling(const paddle::Tensor& x, - const paddle::Tensor& ps, - const paddle::optional& threshold, - const paddle::optional& topp_seed, - int seed, - int k, - const std::string& mode) { - switch (x.type()) { - case paddle::DataType::FLOAT32: { - return LaunchTopPSampling(x, ps, threshold, topp_seed, seed, k, mode); - } - // case paddle::DataType::BFLOAT16: { - // return LaunchTopPSampling(x, ps, threshold, topp_seed, seed, k, mode); - // } - // case paddle::DataType::FLOAT16: { - // return LaunchTopPSampling(x, ps, threshold, topp_seed, seed, k, mode); - // } - default: { - PD_THROW( - "NOT supported data type. Only support float. 
"); - break; - } - } -} - -std::vector> GetTopPSamplingShape(const std::vector& x_shape, - const std::vector& ps_shape, - const paddle::optional>& threshold_shape, - const paddle::optional>& topp_seed_shape, - int seed, - int k) { - int bs = x_shape[0]; - int vocab_size = x_shape[1]; - return {{bs, 1}, {bs, 1}}; -} - -std::vector GetTopPSamplingDtype(const paddle::DataType& x_dytpe, - const paddle::DataType& ps_dtype, - const paddle::optional& threshold_dtype, - const paddle::optional& topp_seed_dtype, - int seed, - int k) { - return {x_dytpe, paddle::DataType::INT64}; -} - -PD_BUILD_STATIC_OP(air_topp_sampling) - .Inputs({"x", "ps", paddle::Optional("threshold"),paddle::Optional("topp_seed") }) - .Outputs({"out", "ids"}) - .Attrs({"seed: int", "k: int", "mode: std::string"}) - .SetKernelFn(PD_KERNEL(TopPSampling)) - .SetInferShapeFn(PD_INFER_SHAPE(GetTopPSamplingShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(GetTopPSamplingDtype)); \ No newline at end of file diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 1fab55b4e..42bae453e 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -17,15 +17,12 @@ #include "paddle/phi/core/memory/memcpy.h" template -__global__ void GetMaxLenKernel(const int *seq_lens, - const int *seq_lens_this_time, - const int *seq_lens_encoder, - const int *seq_lens_this_time_merged, - const int *seq_lens_encoder_merged, - const int *seq_mapping, - const int *system_lens, - int *max_lens, - const int batch_size) { +__global__ void +GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, + const int *seq_lens_encoder, + const int *seq_lens_this_time_merged, + const int *seq_lens_encoder_merged, const int *seq_mapping, + const int *system_lens, int *max_lens, const int batch_size) { const int tid = threadIdx.x; typedef cub::BlockReduce BlockReduce; @@ -41,43 +38,61 @@ __global__ void GetMaxLenKernel(const int *seq_lens, int max_dec_len_without_system_this_thread = 0; for (int i = tid; i < batch_size; i += blockDim.x) { const int seq_len_this_time = seq_lens_this_time[i]; - max_len_this_time_this_thread = max(seq_len_this_time, - max_len_this_time_this_thread); - max_len_encoder_this_thread = max(seq_lens_encoder[i], - max_len_encoder_this_thread); + max_len_this_time_this_thread = + max(seq_len_this_time, max_len_this_time_this_thread); + max_len_encoder_this_thread = + max(seq_lens_encoder[i], max_len_encoder_this_thread); max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread); - if (seq_len_this_time <= 0) continue; + if (seq_len_this_time <= 0) + continue; const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 
0 : seq_lens[i]; - max_len_this_thread = max(seq_lens[i] + seq_len_this_time, - max_len_this_thread); - max_just_dec_len_this_thread = max(max_just_dec_len_this_thread, - max_just_dec_len_now); + max_len_this_thread = + max(seq_lens[i] + seq_len_this_time, max_len_this_thread); + max_just_dec_len_this_thread = + max(max_just_dec_len_this_thread, max_just_dec_len_now); if (system_lens) { const int real_bid = seq_mapping[i]; const int system_len_now = system_lens[real_bid]; - max_system_len_this_thread = max(max_system_len_this_thread, system_len_now); - max_dec_len_without_system_this_thread = max(max_dec_len_without_system_this_thread, - max_just_dec_len_now - system_len_now); + max_system_len_this_thread = + max(max_system_len_this_thread, system_len_now); + max_dec_len_without_system_this_thread = + max(max_dec_len_without_system_this_thread, + max_just_dec_len_now - system_len_now); } } if (system_lens) { for (int i = tid; i < batch_size; i += blockDim.x) { const int ori_seq_len_this_time = seq_lens_this_time_merged[i]; - if (ori_seq_len_this_time <= 0) continue; - const int max_just_dec_merged_len_this_time_now = seq_lens_encoder_merged[i] > 0 ? - 0 : ori_seq_len_this_time; - max_just_dec_merged_len_this_time_this_thread = max(max_just_dec_merged_len_this_time_this_thread, - max_just_dec_merged_len_this_time_now); + if (ori_seq_len_this_time <= 0) + continue; + const int max_just_dec_merged_len_this_time_now = + seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time; + max_just_dec_merged_len_this_time_this_thread = + max(max_just_dec_merged_len_this_time_this_thread, + max_just_dec_merged_len_this_time_now); } } - int total_max_len_this_time = BlockReduce(temp_storage).Reduce(max_len_this_time_this_thread, MaxOp()); - int total_max_len_encoder = BlockReduce(temp_storage).Reduce(max_len_encoder_this_thread, MaxOp()); - int total_max_len_decoder = BlockReduce(temp_storage).Reduce(max_len_decoder_this_thread, MaxOp()); - int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); - int total_just_dec = BlockReduce(temp_storage).Reduce(max_just_dec_len_this_thread, MaxOp()); - int total_just_dec_merged = BlockReduce(temp_storage).Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp()); - int total_system_len = BlockReduce(temp_storage).Reduce(max_system_len_this_thread, MaxOp()); - int total_dec_len_without_system = BlockReduce(temp_storage).Reduce(max_dec_len_without_system_this_thread, MaxOp()); + int total_max_len_this_time = + BlockReduce(temp_storage) + .Reduce(max_len_this_time_this_thread, MaxOp()); + int total_max_len_encoder = + BlockReduce(temp_storage) + .Reduce(max_len_encoder_this_thread, MaxOp()); + int total_max_len_decoder = + BlockReduce(temp_storage) + .Reduce(max_len_decoder_this_thread, MaxOp()); + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + int total_just_dec = BlockReduce(temp_storage) + .Reduce(max_just_dec_len_this_thread, MaxOp()); + int total_just_dec_merged = + BlockReduce(temp_storage) + .Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp()); + int total_system_len = BlockReduce(temp_storage) + .Reduce(max_system_len_this_thread, MaxOp()); + int total_dec_len_without_system = + BlockReduce(temp_storage) + .Reduce(max_dec_len_without_system_this_thread, MaxOp()); if (tid == 0) { max_lens[0] = total_max_len_this_time; max_lens[1] = total_max_len_encoder; @@ -90,30 +105,22 @@ __global__ void GetMaxLenKernel(const int *seq_lens, } } -void GetMaxLen(const paddle::Tensor& seq_lens_tensor, - const 
paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - paddle::Tensor &max_len_tensor, - const int batch_size) { +void GetMaxLen(const paddle::Tensor &seq_lens_tensor, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + paddle::Tensor &max_len_tensor, const int batch_size) { constexpr int blockSize = 1024; GetMaxLenKernel<<<1, blockSize, 0, seq_lens_encoder.stream()>>>( - seq_lens_tensor.data(), - seq_lens_this_time.data(), - seq_lens_encoder.data(), - nullptr, - nullptr, - nullptr, - nullptr, - max_len_tensor.data(), - batch_size); + seq_lens_tensor.data(), seq_lens_this_time.data(), + seq_lens_encoder.data(), nullptr, nullptr, nullptr, nullptr, + max_len_tensor.data(), batch_size); } -__global__ void split_q_block(const int* __restrict__ seq_lens_q, - const int* __restrict__ seq_lens_encoder, - int* __restrict__ batch_ids, - int* __restrict__ tile_ids_per_batch, - int* __restrict__ num_blocks_x, - const int bsz, +__global__ void split_q_block(const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_encoder, + int *__restrict__ batch_ids, + int *__restrict__ tile_ids_per_batch, + int *__restrict__ num_blocks_x, const int bsz, const int num_rows_per_block, const int group_size) { if (threadIdx.x == 0) { @@ -124,8 +131,7 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q, if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { seq_len = 0; } - const int loop_times = - div_up(seq_len * group_size, num_rows_per_block); + const int loop_times = div_up(seq_len * group_size, num_rows_per_block); for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { batch_ids[index] = bid; tile_ids_per_batch[index++] = tile_id; @@ -136,14 +142,12 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q, } } -__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder, - const int* __restrict__ seq_lens_encoder, - int* __restrict__ batch_ids, - int* __restrict__ tile_ids_per_batch, - int* __restrict__ num_blocks_x, - const int bsz, - const int pad_len, - const int num_row_per_block) { +__global__ void split_kv_block(const int *__restrict__ seq_lens_decoder, + const int *__restrict__ seq_lens_encoder, + int *__restrict__ batch_ids, + int *__restrict__ tile_ids_per_batch, + int *__restrict__ num_blocks_x, const int bsz, + const int pad_len, const int num_row_per_block) { if (threadIdx.x == 0) { int gridx = 0; int index = 0; @@ -165,50 +169,46 @@ __global__ void split_kv_block(const int* __restrict__ seq_lens_decoder, } template -__global__ void get_max_len_kv_ernel(int* max_seq_lens_out, - const int* seq_lens_this_time, - const int* seq_lens_decoder, - const int batch_size) { +__global__ void +get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time, + const int *seq_lens_decoder, const int batch_size) { const int tid = threadIdx.x; - typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; int max_len_this_thread = 0; for (int i = tid; i < batch_size; i += blockDim.x) { - if (seq_lens_decoder[i] == 0) continue; - max_len_this_thread = max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread); + if (seq_lens_decoder[i] == 0) + continue; + max_len_this_thread = + max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread); } - int total = BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); if (tid == 0) { *max_seq_lens_out = total; 
} } std::vector GetBlockShapeAndSplitKVBlock( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& cum_offsets, - const int encoder_block_shape_q, - const int decoder_block_shape_q, - const int group_size, - const int block_size, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cum_offsets, + const int encoder_block_shape_q, const int decoder_block_shape_q, + const int group_size, const int block_size, const int decoder_step_token_num) { auto stream = seq_lens_encoder.stream(); int bsz = cum_offsets.shape()[0]; auto max_len_tensor = GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place()); - GetMaxLen( - seq_lens_decoder, - seq_lens_this_time, - seq_lens_encoder, - max_len_tensor, - bsz); + GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, + max_len_tensor, bsz); - // max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time, - // max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system + // max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, + // max_enc_dec_len_this_time, max_just_dec_len_this_time, + // max_just_dec_merged_len_this_time, max_system_len, + // max_just_dec_len_without_system auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false); auto max_len_cpu_ptr = max_len_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -229,67 +229,67 @@ std::vector GetBlockShapeAndSplitKVBlock( paddle::Tensor decoder_batch_ids; paddle::Tensor decoder_tile_ids_per_batch; paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor max_len_kv_cpu; /*cpu*/ + paddle::Tensor max_len_kv_cpu; /*cpu*/ auto max_len_kv = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>( - max_len_kv.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - bsz - ); + max_len_kv.data(), seq_lens_this_time.data(), + seq_lens_decoder.data(), bsz); - max_len_kv_cpu = - max_len_kv.copy_to(paddle::CPUPlace(), false); + max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false); if (max_enc_len_this_time > 0) { - const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); - kv_batch_ids = GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, - paddle::DataType::INT32, - seq_lens_encoder.place()); - kv_tile_ids_per_batch = GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, - paddle::DataType::INT32, - seq_lens_encoder.place()); + const uint32_t max_tile_size_per_bs_kv = + div_up(max_enc_dec_len_this_time, block_size); + kv_batch_ids = + GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, + seq_lens_encoder.place()); + kv_tile_ids_per_batch = + GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, + seq_lens_encoder.place()); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>( - seq_lens_decoder.data(), - // sequence_lengths->data(), - seq_lens_encoder.data(), - kv_batch_ids.data(), - kv_tile_ids_per_batch.data(), - kv_num_blocks_x.data(), - bsz, - block_size, - block_size - ); + seq_lens_decoder.data(), + // sequence_lengths->data(), + seq_lens_encoder.data(), kv_batch_ids.data(), + kv_tile_ids_per_batch.data(), kv_num_blocks_x.data(), bsz, + 
block_size, block_size); kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false); - const uint32_t encoder_max_tile_size_per_bs_q = div_up( - (max_enc_dec_len_this_time * group_size), encoder_block_shape_q); + const uint32_t encoder_max_tile_size_per_bs_q = + div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); encoder_batch_ids = GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, - seq_lens_encoder.place()); + paddle::DataType::INT32, seq_lens_encoder.place()); encoder_tile_ids_per_batch = GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, - seq_lens_encoder.place()); + paddle::DataType::INT32, seq_lens_encoder.place()); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), - nullptr, + split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), nullptr, encoder_batch_ids.data(), encoder_tile_ids_per_batch.data(), - encoder_num_blocks_x.data(), - bsz, - encoder_block_shape_q, - group_size); + encoder_num_blocks_x.data(), bsz, + encoder_block_shape_q, group_size); encoder_num_blocks_x_cpu = encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); + } else { + encoder_batch_ids = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); + encoder_tile_ids_per_batch = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); + encoder_num_blocks_x_cpu = + GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); + kv_batch_ids = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); + kv_tile_ids_per_batch = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); + kv_num_blocks_x_cpu = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); } if (max_just_dec_len_this_time > 0) { const uint32_t decoder_max_tile_size_per_bs_q = @@ -297,24 +297,26 @@ std::vector GetBlockShapeAndSplitKVBlock( decoder_batch_ids = GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, - seq_lens_encoder.place()); + paddle::DataType::INT32, seq_lens_encoder.place()); decoder_tile_ids_per_batch = GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, - seq_lens_encoder.place()); + paddle::DataType::INT32, seq_lens_encoder.place()); auto decoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data(), - seq_lens_encoder.data(), - decoder_batch_ids.data(), - decoder_tile_ids_per_batch.data(), - decoder_num_blocks_x.data(), - bsz, - decoder_block_shape_q, - group_size); + split_q_block<<<1, 32, 0, stream>>>( + seq_lens_this_time.data(), seq_lens_encoder.data(), + decoder_batch_ids.data(), decoder_tile_ids_per_batch.data(), + decoder_num_blocks_x.data(), bsz, decoder_block_shape_q, + group_size); decoder_num_blocks_x_cpu = decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); + } else { + decoder_batch_ids = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); + decoder_tile_ids_per_batch = + GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); + decoder_num_blocks_x_cpu = + GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); } return {encoder_batch_ids, @@ -331,28 +333,22 @@ std::vector GetBlockShapeAndSplitKVBlock( } std::vector GetBlockShapeAndSplitKVBlockInferDtype( - const paddle::DataType& 
seq_lens_encoder_dtype, - const paddle::DataType& seq_lens_decoder_dtype, - const paddle::DataType& seq_lens_this_time_dtype, - const paddle::DataType& cum_offsets_dtype) { - return {paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32, - paddle::DataType::INT32}; + const paddle::DataType &seq_lens_encoder_dtype, + const paddle::DataType &seq_lens_decoder_dtype, + const paddle::DataType &seq_lens_this_time_dtype, + const paddle::DataType &cum_offsets_dtype) { + return { + paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, + paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, + paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, + paddle::DataType::INT32, paddle::DataType::INT32}; } std::vector> GetBlockShapeAndSplitKVBlockInferShape( - const std::vector& seq_lens_encoder_shape, - const std::vector& seq_lens_decoder_shape, - const std::vector& seq_lens_this_time_shape, - const std::vector& cum_offsets_shape) { + const std::vector &seq_lens_encoder_shape, + const std::vector &seq_lens_decoder_shape, + const std::vector &seq_lens_this_time_shape, + const std::vector &cum_offsets_shape) { std::vector dynamic_shape = {-1}; return {dynamic_shape, @@ -369,9 +365,7 @@ std::vector> GetBlockShapeAndSplitKVBlockInferShape( } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) - .Inputs({"seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time", + .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", "cum_offsets"}) .Outputs({paddle::Optional("encoder_batch_ids"), paddle::Optional("encoder_tile_ids_per_batch"), @@ -382,12 +376,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) paddle::Optional("decoder_batch_ids"), paddle::Optional("decoder_tile_ids_per_batch"), paddle::Optional("decoder_num_blocks"), - paddle::Optional("max_len_kv"), - "set_max_lengths"}) - .Attrs({"encoder_block_shape_q: int", - "decoder_block_shape_q: int", - "group_size: int", - "block_size: int", + paddle::Optional("max_len_kv"), "set_max_lengths"}) + .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int", + "group_size: int", "block_size: int", "decoder_step_token_num: int"}) .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index a0a095597..5be300177 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -337,6 +337,8 @@ __forceinline__ __host__ __device__ void vec_cast( } else if (deal_each_time == 64) { \ constexpr size_t DEAL_EACH_TIME = 64; \ __VA_ARGS__ \ + } else { \ + PD_THROW("not support the deal_each_time", deal_each_time); \ } #define DISPATCH_NUM_THREADS(num_threads, NUM_THREADS, ...) \ @@ -346,6 +348,8 @@ __forceinline__ __host__ __device__ void vec_cast( } else if (num_threads == 256) { \ constexpr size_t NUM_THREADS = 256; \ __VA_ARGS__ \ + } else { \ + PD_THROW("not support the num_threads", num_threads); \ } #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ @@ -376,6 +380,11 @@ __forceinline__ __host__ __device__ void vec_cast( } else if (group_size == 12) { \ constexpr size_t GROUP_SIZE = 12; \ __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ } #define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ diff --git a/custom_ops/gpu_ops/cpp_extensions.cu b/custom_ops/gpu_ops/cpp_extensions.cc similarity index 61% rename from custom_ops/gpu_ops/cpp_extensions.cu rename to custom_ops/gpu_ops/cpp_extensions.cc index dbd195098..60920b629 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cu +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/extension.h" - +#include "pybind11/pybind11.h" namespace py = pybind11; // 自定义异常类,用于处理CUDA错误 @@ -125,45 +125,40 @@ paddle::Tensor FusedExpertMoeFunc( const bool norm_topk_prob, const bool group_moe); std::vector MoeExpertDispatch( - const paddle::Tensor& input, - const paddle::Tensor& gating_output, - const paddle::optional& gating_correction_bias, - const paddle::optional &w4a8_in_scale, - const int moe_topk, - const bool group_moe, - const bool topk_only_mode); + const paddle::Tensor &input, const paddle::Tensor &gating_output, + const paddle::optional &gating_correction_bias, + const paddle::optional &w4a8_in_scale, const int moe_topk, + const bool group_moe, const bool topk_only_mode); std::vector MoETopKSelectKernel(const paddle::Tensor &gating_logits, - const paddle::optional &bias, - const int moe_topk, const bool apply_norm_weight, - const bool enable_softmax_top_k_fused); + const paddle::optional &bias, + const int moe_topk, const bool apply_norm_weight, + const bool enable_softmax_top_k_fused); -std::vector MoERedundantTopKSelectKernel( - const paddle::Tensor& gating_logits, - const paddle::Tensor& expert_id_to_ep_rank_array, - const paddle::Tensor& expert_in_rank_num_list, - paddle::Tensor& tokens_per_expert_stats_list, - const paddle::optional& bias, - const int moe_topk, - const bool apply_norm_weight, - const bool enable_softmax_top_k_fused, - const int redundant_ep_rank_num_plus_one); +std::vector +MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits, + const paddle::Tensor &expert_id_to_ep_rank_array, + const paddle::Tensor &expert_in_rank_num_list, + paddle::Tensor &tokens_per_expert_stats_list, + const paddle::optional &bias, + const int moe_topk, const bool apply_norm_weight, + const bool enable_softmax_top_k_fused, + const int redundant_ep_rank_num_plus_one); std::vector EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids, - const paddle::Tensor &topk_weights, - const paddle::optional &ffn1_in_scale, - const std::vector &token_nums_per_expert, - const int token_nums_this_rank, - const std::string &moe_quant_type); + const paddle::Tensor &topk_weights, + const paddle::optional &ffn1_in_scale, + const std::vector &token_nums_per_expert, + const int token_nums_this_rank, + const std::string &moe_quant_type); std::vector EPMoeExpertDispatchFP8( const paddle::Tensor &input, const paddle::Tensor &scale, const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights, - const std::vector &token_nums_per_expert, - const std::vector &token_nums_per_expert_padded, - const int token_nums_this_rank, const int token_nums_this_rank_padded); + const paddle::Tensor &token_nums_per_expert, + const paddle::Tensor &token_nums_per_expert_padded); std::vector 
PerTokenQuant(paddle::Tensor &input, const int block_size); @@ -180,20 +175,35 @@ std::vector EPMoeExpertCombine( const paddle::optional &ffn2_bias, const bool norm_topk_prob, const float routed_scaling_factor); -std::vector> GetExpertTokenNum( - const paddle::Tensor& topk_ids, - const int num_experts); +std::vector> GetExpertTokenNum(const paddle::Tensor &topk_ids, + const int num_experts); paddle::Tensor MoeExpertFFNFunc( - const paddle::Tensor &permute_input, - const paddle::Tensor &tokens_expert_prefix_sum, - const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight, - const paddle::optional &ffn1_bias, - const paddle::optional &ffn1_scale, - const paddle::optional &ffn2_scale, - const paddle::optional &ffn2_in_scale, - const paddle::optional &expert_idx_per_token, - const std::string &quant_method, const bool used_in_ep_low_latency); + const paddle::Tensor& permute_input, + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn1_bias, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn2_scale, + const paddle::optional& ffn2_in_scale, + const paddle::optional& expert_idx_per_token, + const std::string& quant_method, const bool used_in_ep_low_latency); + +paddle::Tensor MoeExpertFFNWint2Func( + const paddle::Tensor& permute_input, + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& ffn1_weight, + const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn1_bias, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn2_scale, + const paddle::optional& ffn1_local_scale, + const paddle::optional& ffn1_code_scale, + const paddle::optional& ffn1_code_zp, + const paddle::optional& ffn2_local_scale, + const paddle::optional& ffn2_code_scale, + const paddle::optional& ffn2_code_zp, + const bool used_in_ep_low_latency); paddle::Tensor MoeExpertReduceFunc( const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, @@ -205,19 +215,16 @@ paddle::Tensor MoeExpertReduceFunc( void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor, const paddle::Tensor &seq_lens_this_time_tensor, const paddle::Tensor &seq_lens_decoder_tensor, - const int rank, - const int num_layers); + const int rank, const int num_layers); -void GetOutputKVSignal(const paddle::Tensor& x, - int64_t rank_id, +void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id, bool wait_flag); - paddle::Tensor DequantInt8Func(const paddle::Tensor &input, const paddle::Tensor &out_scale, std::string dtype); -paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, +paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id, const bool keep_pd_step_flag); paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata, @@ -286,61 +293,121 @@ std::vector ExtractTextTokenOutput( const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text); -std::vector MoEDeepGEMMPermute( - const paddle::Tensor& x, - const paddle::Tensor& topk_idx, - const int num_experts, - const int max_tokens_per_expert -); +std::vector MoEDeepGEMMPermute(const paddle::Tensor &x, + const paddle::Tensor &topk_idx, + const int num_experts, + const int max_tokens_per_expert); std::vector MoEDeepGEMMDePermute( - const paddle::Tensor& ffn_out, // [num_experts, max_tokens_per_expert, hidden] - const paddle::Tensor& permute_indices_per_token, // [token_num, topk}] - const paddle::Tensor& topk_idx, - const 
paddle::Tensor& topk_weights -); + const paddle::Tensor + &ffn_out, // [num_experts, max_tokens_per_expert, hidden] + const paddle::Tensor &permute_indices_per_token, // [token_num, topk}] + const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights); + +void TextImageIndexOut(const paddle::Tensor &token_type_ids, + const paddle::Tensor &text_input, + const paddle::Tensor &image_input); + +void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input, + paddle::Tensor &image_input, + paddle::Tensor &token_type_ids, + paddle::Tensor &text_index, + paddle::Tensor &image_index, const bool is_scatter); + +paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids, + int64_t num_experts); +std::vector tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M); + + +std::vector MoeWna16MarlinGemmApi( + const paddle::Tensor& a, + const paddle::optional& c_or_none, + const paddle::Tensor& b_q_weight, + const paddle::Tensor& b_scales, + const paddle::optional& global_scale_or_none, + const paddle::optional& b_zeros_or_none, + const paddle::optional& g_idx_or_none, + const paddle::optional& perm_or_none, + const paddle::Tensor& workspace, + const paddle::Tensor& sorted_token_ids, + const paddle::Tensor& expert_ids, + const paddle::Tensor& num_tokens_post_padded, + const paddle::Tensor& topk_weights, + int64_t moe_block_size, + int64_t top_k, + bool mul_topk_weights, + bool is_ep, + const std::string& b_q_type_str, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float); +void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a, + paddle::Tensor const &b, paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::optional const &bias); + +void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a, + paddle::Tensor const& b, + paddle::Tensor const& a_scales, + paddle::Tensor const& b_scales, + paddle::Tensor const& azp_adj, + paddle::optional const& azp, + paddle::optional const& bias); + +void StaticScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, + paddle::Tensor const &scale); + +void DynamicScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, + paddle::Tensor &scale); + +void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, + paddle::Tensor const &input, + paddle::Tensor &scales, float scale_ub); PYBIND11_MODULE(fastdeploy_ops, m) { - m.def("get_expert_token_num", &GetExpertTokenNum, - py::arg("topk_ids"), py::arg("num_experts"), - "get expert token num"); + m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), + py::arg("num_experts"), "get expert token num"); + /** + * moe/fused_moe/moe_redundant_topk_select.cu + * moe_redundant_topk_select + */ + m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel, + py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"), + py::arg("expert_in_rank_num_list"), + py::arg("tokens_per_expert_stats_list"), py::arg("bias"), + py::arg("moe_topk"), py::arg("apply_norm_weight"), + py::arg("enable_softmax_top_k_fused"), + py::arg("redundant_ep_rank_num_plus_one"), + "moe export RedundantTopKSelect function"); - /** - * moe/fused_moe/moe_redundant_topk_select.cu - * moe_redundant_topk_select - */ - m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel, - py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"), - py::arg("expert_in_rank_num_list"), 
py::arg("tokens_per_expert_stats_list"), - py::arg("bias"), py::arg("moe_topk"), py::arg("apply_norm_weight"), - py::arg("enable_softmax_top_k_fused"), py::arg("redundant_ep_rank_num_plus_one"), - "moe export RedundantTopKSelect function"); + /** + * open_shm_and_get_meta_signal.cc + * InitKVSignalPerQuery + */ + m.def("init_kv_signal_per_query", &InitKVSignalPerQuery, + py::arg("seq_lens_encoder_tensor"), + py::arg("seq_lens_this_time_tensor"), + py::arg("seq_lens_decoder_tensor"), py::arg("rank"), + py::arg("num_layers"), "init_kv_signal_per_query function"); + /** + * GetOutputKVSignal + */ + m.def("get_output_kv_signal", &GetOutputKVSignal, py::arg("x"), + py::arg("rank_id"), py::arg("wait_flag"), + "get_output_kv_signal function"); - /** - * open_shm_and_get_meta_signal.cc - * InitKVSingnalPerQuery - */ - m.def("init_kv_signal_per_query", &InitKVSignalPerQuery, - py::arg("seq_lens_encoder_tensor"), py::arg("seq_lens_this_time_tensor"), - py::arg("seq_lens_decoder_tensor"), py::arg("rank"), py::arg("num_layers"), - "init_kv_signal_per_query function"); - - /** - * GetOutputKVSignal - */ - m.def("get_output_kv_signal", &GetOutputKVSignal, - py::arg("x"), py::arg("rank_id"), py::arg("wait_flag"), - "get_output_kv_signal function"); - - - - m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute"); - m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute, "MoEDeepGEMMDePermute"); + m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute"); + m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute, + "MoEDeepGEMMDePermute"); /** * alloc_cache_pinned.cc * cuda_host_alloc @@ -398,12 +465,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"), py::arg("moe_quant_type"), "ep moe export dispatch function"); - m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8, py::arg("input"), - py::arg("scale"), py::arg("topk_ids"), py::arg("topk_weights"), - py::arg("token_nums_per_expert"), - py::arg("token_nums_per_expert_padded"), - py::arg("token_nums_this_rank"), py::arg("token_nums_this_rank_padded"), - "ep moe export dispatch function"); + m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8); m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"), py::arg("expert_scales_float"), py::arg("permute_indices_per_token"), @@ -437,6 +499,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) { */ m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function"); + /** + * moe/fused_moe/moe_ffn_wint2.cu + * moe_expert_ffn_wint2 + */ + m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function"); + /** * moe/fused_moe/moe_expert_reduce.cu * moe_expert_reduce @@ -523,4 +591,66 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("group_swiglu_with_masked", &GroupSwigluWithMasked, "group_swiglu_with_masked function"); + + m.def("text_image_index_out", &TextImageIndexOut, + "text_image_index_out function"); + + m.def("text_image_gather_scatter", &TextImageGatherScatter, + "text_image_gather_scatter function"); + + m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func); + m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel); + + m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi, + py::arg("a"), + py::arg("c_or_none"), + py::arg("b_q_weight"), + py::arg("b_scales"), + py::arg("global_scale_or_none"), + py::arg("b_zeros_or_none"), + py::arg("g_idx_or_none"), + py::arg("perm_or_none"), + py::arg("workspace"), + py::arg("sorted_token_ids"), + py::arg("expert_ids"), + 
py::arg("num_tokens_post_padded"), + py::arg("topk_weights"), + py::arg("moe_block_size"), + py::arg("top_k"), + py::arg("mul_topk_weights"), + py::arg("is_ep"), + py::arg("b_q_type_str"), + py::arg("size_m"), + py::arg("size_n"), + py::arg("size_k"), + py::arg("is_k_full"), + py::arg("use_atomic_add"), + py::arg("use_fp32_reduce"), + py::arg("is_zp_float")); + + + /** + * cutlass_scaled_mm.cu + * cutlass_scaled_mm + * cutlass_scaled_mm_azp + */ + m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function"); + m.def("cutlass_scaled_mm_azp", &CutlassScaledMmAzp, "cutlass_scaled_mm_azp function"); + + /** + * quantization/common.cu + * static_scaled_fp8_quant + * dynamic_scaled_fp8_quant + * dynamic_per_token_scaled_fp8_quant + */ + m.def("static_scaled_fp8_quant", &StaticScaledFp8Quant, "static_scaled_fp8_quant function", + py::arg("out"), py::arg("input"), py::arg("scale")); + + m.def("dynamic_scaled_fp8_quant", &DynamicScaledFp8Quant, + "dynamic_scaled_fp8_quant function", + py::arg("out"), py::arg("input"), py::arg("scale")); + + m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant, + "dynamic_per_token_scaled_fp8_quant function", + py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub")); } diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h b/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h new file mode 100644 index 000000000..a9975c013 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Architecture-specific operators on memory added for SM80 +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/arch/cache_operation.h" + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initiates an asynchronous copy from global memory to shared memory. +/// +/// cp.async +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always, + bool GlobalToShared = true> +struct copy; + +/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate +/// the entire transfer, zeros are written to SMEM if the guard predicate is false. +/// +/// cp.async +/// +template < + /// Size of the access in bytes + int SizeInBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always, + bool GlobalToShared = true> +struct copy_zfill; + +/// Blocks until all but previous cp.async.commit_group operations have committed. +/// +/// cp.async +/// +template +struct copy_wait; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy { + + /// Copy + CUTLASS_DEVICE + copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + cp_async(smem_ptr, global_ptr, pred_guard); + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy { + + /// Copy + CUTLASS_DEVICE + copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy_zfill { + + /// Copy with zero fill + CUTLASS_DEVICE + copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { + cp_async_zfill(smem_ptr, global_ptr, pred_guard); + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy_zfill { + + /// Copy with zero fill + CUTLASS_DEVICE + copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy { + + /// Copy + CUTLASS_DEVICE + copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + cp_async(smem_ptr, global_ptr, pred_guard); + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy { + + /// Copy + CUTLASS_DEVICE + copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy_zfill { + + /// Copy with zero fill + 
CUTLASS_DEVICE + copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + cp_async_zfill(smem_ptr, global_ptr, pred_guard); + } +}; + +/// Partial specialization +template < + /// Size of the access in bytes + int SizeInBytes> +struct copy_zfill { + + /// Copy with zero fill + CUTLASS_DEVICE + copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + } +}; + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +template +CUTLASS_DEVICE +void copy_fence() {} + +template <> +CUTLASS_DEVICE +void copy_fence() { + cp_async_fence(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization +template +struct copy_wait { + + CUTLASS_DEVICE + copy_wait() {} +}; + +/// Partial specialization +template +struct copy_wait { + + CUTLASS_DEVICE + copy_wait() { cp_async_wait(); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp new file mode 100644 index 000000000..3e5aa4b03 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -0,0 +1,460 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// + +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp + +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcastArray { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. 
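+  // Usage sketch (hypothetical device pointers): when row_broadcast is true,
+  // ptr_row_array[g] is read as a length-N row vector for group g (one value
+  // per output column, e.g. per-channel scales); when row_broadcast is false,
+  // ptr_row_array[g] points to a single device-side scalar that is broadcast
+  // to the whole tile, e.g.
+  //   Arguments{d_scale_ptrs, /*row_broadcast=*/false, {}};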
+ struct Arguments { + const Element* const* ptr_row_array = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, + int group, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , group(group) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + int group; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row_array[group])); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + l, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcastArray { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
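+  // Mirrors the row variant above, but along M: with col_broadcast set,
+  // ptr_col_array[g] is read as a length-M column vector (e.g. per-row
+  // alpha/bias or per-token scales); otherwise the single device-side scalar
+  // it points to is broadcast to every row of group g.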
+ struct Arguments { + const Element* const* ptr_col_array = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + int group, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + group(group), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + int group; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col_array[group])); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + l, + params + ); + } +}; + +} diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 000000000..fa1df1fb1 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// + +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp + +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void 
const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and 
broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. 
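+  // At the start of the epilogue this visitor either gathers the predicated
+  // column vector from global memory (col_broadcast == true) or fills the
+  // register fragment with the single device-side scalar *ptr_col, so one
+  // compiled kernel covers both per-row and per-tensor scale layouts.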
+ struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 000000000..7b56c3c1a --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. 
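+//
+// As an illustration (hypothetical pointers), the broadcast mode is selected
+// purely by the Arguments of the visitors below:
+//   Sm90RowOrScalarBroadcast<...>::Arguments{d_b_scales, /*row_broadcast=*/true,  {}};  // per-channel vector
+//   Sm90RowOrScalarBroadcast<...>::Arguments{d_scalar,   /*row_broadcast=*/false, {}};  // per-tensor scalar
+// In both cases the scale data stays on the device; only the bool changes.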
+// + +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp + +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // 
(CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row)); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + 
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); + } +}; + +} diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 000000000..513d3741f --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,327 @@ +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp + +#pragma once + +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace fastdeploy::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { +protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. 
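+  // In practice this resolves to two argument shapes (see the if constexpr in
+  // the function below): {ptr, numel != 1} for the ColOrScalarLoad /
+  // RowOrScalarLoad descriptors, where a one-element tensor is treated as a
+  // scalar to broadcast, and {ptr} for the plain Col/Row/RowOrZero loads.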
+ template + static auto args_from_tensor(paddle::Tensor const &tensor) { + using Arguments = typename Descriptor::Arguments; + auto *data_ptr = static_cast(const_cast( + tensor.data())); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } + else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(paddle::optional const &tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto *data_ptr = + tensor ? static_cast(const_cast(tensor->data())) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + paddle._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
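+ *
+ * Concretely (illustrative): with per-token a_scales of shape (m,1),
+ * per-channel b_scales of shape (1,n) and a per-channel bias of shape (1,n),
+ *
+ *   D[i,j] = a_scales[i] * (b_scales[j] * Acc[i,j]) + bias[j]
+ *
+ * which is exactly the multiplies (EVTCompute0) followed by multiply_add
+ * (Compute1) tree built below.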
+ */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { +protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType + prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; 
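+
+/*
+ * Editorial sketch (not part of the upstream file): a scalar CPU reference for
+ * the per-tensor azp epilogue above. With asymmetric activation quantization
+ * A_q = A/s_a + azp * J, the accumulator Acc = A_q @ B_q carries an unwanted
+ * term azp * (J @ B_q); subtracting the precomputed row vector
+ * azp_adj = azp * (J @ B_q) recovers the scaled product before bias is added.
+ * Names here are hypothetical; int32_t requires <cstdint>.
+ */
+inline void scaled_mm_azp_reference(const int32_t* acc,      // m x n accumulators
+                                    const int32_t* azp_adj,  // n, already azp * (J @ B)
+                                    float a_scale,           // per-tensor scale of A
+                                    const float* b_scales,   // n, per-channel scales of B
+                                    const float* bias,       // n, per-channel bias
+                                    float* d, int m, int n) {
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < n; ++j) {
+      // Mirrors the EVT tree: minus -> multiplies(scale_b) -> multiply_add(scale_a, ..., bias)
+      float corrected = static_cast<float>(acc[i * n + j] - azp_adj[j]);
+      d[i * n + j] = a_scale * (b_scales[j] * corrected) + bias[j];
+    }
+  }
+}
+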
+ +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType + prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, paddle::Tensor const &azp, + paddle::optional const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +}; // namespace fastdeploy::c2x diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 000000000..38a51d914 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,453 @@ +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp + +#pragma once + +// clang-format will break include orders +// clang-format off +#include 
"cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" +// clang-format on + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace fastdeploy::c3x { + +using namespace cute; + +template struct identity { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { return lhs; } +}; + +template +struct TrivialEpilogue { +private: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using Compute = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + template static ArgumentType prepare_args(Args... args) { + return {}; + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { +protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; + + template + using ColOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoadArray = + cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(paddle::Tensor const &tensor) { + using Arguments = typename Descriptor::Arguments; + auto *data_ptr = static_cast(const_cast(tensor.data())); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(paddle::optional const &tensor) { + using Arguments = typename Descriptor::Arguments; + auto *data_ptr = + tensor ? 
static_cast(const_cast(tensor->data())) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } + + template + static auto args_from_tensor(const T *const *data_ptr, bool do_broadcast) { + using Arguments = typename Descriptor::Arguments; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr, do_broadcast}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + paddle.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
+ */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogueBias, but the + * bias is a column vector instead of a row vector. Useful e.g. if we are + * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. + */ +template +struct ScaledEpilogueColumnBias + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template ColLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, bias_args, {}}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
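+ *
+ * Why this works (illustrative): with per-tensor azp the quantized activation
+ * is A_q = A/s_a + azp * J, where J is the all-ones matrix, so
+ *
+ *   Acc = A_q @ B = (A/s_a) @ B + azp * (J @ B)
+ *
+ * and subtracting the precomputed row vector azp_adj = azp * (J @ B) from the
+ * accumulator removes the zero-point contribution before scaling.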
+ */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType + prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_azp_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
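+ *
+ * In other words (illustrative), the full correction is the outer product of
+ * azp and azp_adj:
+ *
+ *   correction[i,j] = azp[i] * azp_adj[j],   Acc'[i,j] = Acc[i,j] - correction[i,j]
+ *
+ * which the epilogue evaluates on the fly from the (m,1) and (1,n) vectors
+ * instead of materializing the full (m,n) term.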
+ */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType + prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, paddle::Tensor const &azp, + paddle::optional const &bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{ + b_args, evt_acc_args, {}}; + return ArgumentType{a_args, evt_scale_b_args, bias_args, {}}; + } +}; + +/* + This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers + to arrays containing different scales used in group gemm. The number of + pointers in ScaleA and the number of pointers in ScaleB are equal to the + group size. 
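+ *
+ * For example (illustrative): in a grouped GEMM with G groups, a_scales_ptr[g]
+ * points to group g's per-row scales (or a single scalar) and b_scales_ptr[g]
+ * to its per-column scales; the a_col_broadcast / b_row_broadcast flags passed
+ * to prepare_args below state, uniformly for all groups, whether the entries
+ * are vectors or scalars.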
+*/ +template +struct ScaledEpilogueArray + : private ScaledEpilogueBase { +private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoadArray; + using ScaleB = typename SUPER::template RowOrScalarLoadArray; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + +public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + using ScaleAArray = typename SUPER::template ColOrScalarLoadArray; + using ScaleBArray = typename SUPER::template RowOrScalarLoadArray; + + static ArgumentType prepare_args(float const *const *a_scales_ptr, + float const *const *b_scales_ptr, + bool a_col_broadcast, bool b_row_broadcast) { + auto a_args = SUPER::template args_from_tensor( + a_scales_ptr, a_col_broadcast); + auto b_args = SUPER::template args_from_tensor( + b_scales_ptr, b_row_broadcast); + + typename EVTCompute0::Arguments evt0_args{b_args, {}, {}}; + return ArgumentType{a_args, evt0_args, {}}; + } +}; + +}; // namespace fastdeploy::c3x diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl new file mode 100644 index 000000000..af237b01f --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem +// capacity, or overrides with manual count. +template +constexpr int compute_stage_count_or_override_gated( + StageCountAutoCarveout stage_count) { + // 32 bytes to account for barriers etc. + constexpr int stage_barrier_bytes = 32; + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); + constexpr int stage_bytes = [&]() -> int { + if constexpr (SwapAB) { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / + 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + stage_barrier_bytes; + } else { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / + 8 + + stage_barrier_bytes; + } + }(); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated< + arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A()>> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm = + (cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only " + "compatible with FP8 FastAccum version right now\n"); + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, + tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, + tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = + detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v || + IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom( + 
shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = + decltype(detail::ss_smem_selector< + GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = + decltype(detail::ss_smem_selector< + GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = + detail::compute_stage_count_or_override_gated< + detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB, + TileShape_MNK, SwapAB>(StageCountType{}); + using DispatchPolicy = cute::conditional_t< + IsArrayOfPointersGemm, + MainloopSm90ArrayTmaGmmaWarpSpecialized, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t< + IsFP8Input, + MainloopSm90TmaGmmaWarpSpecializedFP8< + PipelineStages, ClusterShape_MNK, KernelScheduleType>, + MainloopSm90TmaGmmaWarpSpecialized>>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated< + DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, + SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB, + SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated< + arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v< + KernelScheduleType, + KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert( + detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert( + !detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = + detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsArrayOfPointersGemm = + (cute::is_same_v< + KernelScheduleType, + KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>); + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v || + IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = + decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), + 
decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = + decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = + detail::compute_stage_count_or_override_gated< + detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, + SwapAB>(StageCountType{}); + using DispatchPolicy = cute::conditional_t< + IsArrayOfPointersGemm, + MainloopSm90ArrayTmaGmmaWarpSpecialized, + MainloopSm90TmaGmmaWarpSpecialized>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated< + DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, + SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB, + SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp new file mode 100644 index 000000000..227aee50f --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -0,0 +1,60 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, + bool SwapAB = false, class Enable = void> +struct CollectiveBuilderGated { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp new file mode 100644 index 000000000..56849ee56 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, + bool SwapAB = false> +struct CollectiveMmaGated { + static_assert(cutlass::detail::dependent_false, + "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 000000000..8ff14a2a4 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,713 @@ +/*************************************************************************************************** + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated< + MainloopSm90TmaGmmaWarpSpecialized, + TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, + SwapAB_> { + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = + MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for " + "this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any + // rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = + cute::conditional_t>>; + using InternalElementB = + cute::conditional_t>>; + using InternalElementAux = + cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> + smem_A; + cute::array_aligned> + smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const *ptr_A; + StrideA dA; + ElementB const *ptr_B0; + ElementB const *ptr_B1; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto 
ptr_B0 = reinterpret_cast(args.ptr_B0); + + Tensor tensor_a = + make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = + make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + if constexpr (SwapAB) { + auto ptr_Aux = reinterpret_cast(args.ptr_B1); + Tensor tensor_aux = + make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy( + GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>( + ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, + args.scale_d1}; + } else { + auto ptr_Aux = reinterpret_cast(args.ptr_B1); + Tensor tensor_aux = + make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy( + GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>( + ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, + args.scale_d1}; + } + } + + template + static bool can_implement(ProblemShape const &problem_shape, + [[maybe_unused]] Arguments const &args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(N, K, L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the " + "minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * + static_cast(sizeof_bits::value)) / + 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * + static_cast(sizeof_bits::value)) / + 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * + static_cast(sizeof_bits::value)) / + 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best + /// performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const &mainloop_params) { + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this 
collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the + /// contract Returned tuple must contain at least two elements, with the first + /// two elements being: gA_mkl - The tma tensor, A after a local tile so it + /// has shape (BLK_M,BLK_K,m,k,l) gB_nkl - The tma tensor, B after a local + /// tile so it has shape (BLK_N,BLK_K,n,k,l) gAux_xkl - The tma tensor, A/B + /// after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) The rest of the + /// tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL, + Params const &mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain + // mapping Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( + make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( + make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( + make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } else { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( + make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void + load(Params const &mainloop_params, MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const &load_inputs, + BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage &shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), + SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = + get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, + block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = + mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = + mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = + SwapAB + ? 
mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice( + cluster_local_block_id.x); + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) + : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = + Layout{}; // (m,n) -> + // block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, + n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = + Layout{}; // (m,n) -> + // block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout( + m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) { + mcast_mask_aux = mcast_mask_a; + } else { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), + tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), + tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), + tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, + FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx, + TensorStorage &shared_tensors, Params const &mainloop_params) { + static_assert(is_rmem::value, + "C tensor must 
be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAux{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), + SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto { + if constexpr (SwapAB) { + return thread_mma.partition_A(sAux); + } else { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto { + if constexpr (SwapAB) { + return thread_mma.make_fragment_A(tCsAux); + } else { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } else { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == + size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; + --k_tile_prologue) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually 
to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accum1); + } else { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accum1); + } else { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to + /// ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 000000000..76ffbdb2e --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,724 @@ 
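The warp-specialized gated mainloop above issues two GMMAs per K-block that share one operand (A when SwapAB is false, B otherwise) and accumulates them separately into accum0 (fed from ptr_B0) and accum1 (fed from ptr_B1). The epilogue that consumes the pair is not part of this hunk; for a gated FFN the usual combination is Activation(scale_d0 * acc0) * (scale_d1 * acc1). The loop below is only a scalar sketch of that combination, assuming SiLU as the Activation; the function name is illustrative and this is not the CUTLASS epilogue.

    #include <cmath>
    #include <cstddef>

    // Scalar stand-in for the Activation template parameter; SiLU is assumed here.
    inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // acc0 holds the A*B0 tile, acc1 the A*B1 tile, both of length n.
    // d[i] = Activation(scale_d0 * acc0[i]) * (scale_d1 * acc1[i])
    void fuse_gated_accumulators(float const* acc0, float const* acc1, float* d,
                                 std::size_t n, float scale_d0, float scale_d1) {
      for (std::size_t i = 0; i < n; ++i) {
        d[i] = silu(scale_d0 * acc0[i]) * (scale_d1 * acc1[i]);
      }
    }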
+/*************************************************************************************************** + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated< + MainloopSm90TmaGmmaWarpSpecializedFP8, + TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, + SwapAB_> { + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = + MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for " + "this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> + smem_A; + cute::array_aligned> + smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const *ptr_A; + StrideA dA; + ElementB const *ptr_B0; + ElementB const *ptr_B1; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, 0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, 0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B0 = reinterpret_cast(args.ptr_B0); + + Tensor tensor_a = + make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = + make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast 
along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + if constexpr (SwapAB) { + auto ptr_Aux = reinterpret_cast(args.ptr_B1); + Tensor tensor_aux = + make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy( + GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>( + ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, + args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } else { + auto ptr_Aux = reinterpret_cast(args.ptr_B1); + Tensor tensor_aux = + make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy( + GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>( + ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, + args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + } + + template + static bool can_implement(ProblemShape const &problem_shape, + [[maybe_unused]] Arguments const &args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(N, K, L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop + * iteration would issue 4 MMA instructions. */ + implementable = implementable && (args.mma_promotion_interval % 4 == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the " + "minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * + static_cast(sizeof_bits::value)) / + 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * + static_cast(sizeof_bits::value)) / + 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * + static_cast(sizeof_bits::value)) / + 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best + /// performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const &mainloop_params) { + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the + /// contract Returned tuple must contain at least two elements, with the first + /// two elements being: gA_mkl - The tma tensor, A after a local tile so it + /// has shape (BLK_M,BLK_K,m,k,l) gB_nkl - The tma tensor, B after a local + /// tile so it has shape (BLK_N,BLK_K,n,k,l) gAux_xkl - The tma tensor, A/B + /// after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL, + Params const &mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain + // mapping Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( + make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( + make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( + make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } else { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( + make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void + load(Params const &mainloop_params, MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const &load_inputs, + BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage &shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), + SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, + block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = + mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = + mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = + SwapAB + ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice( + cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) + : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = + Layout{}; // (m,n) -> + // block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, + n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = + Layout{}; // (m,n) -> + // block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout( + m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) { + mcast_mask_aux = mcast_mask_a; + } else { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType *tma_barrier = + pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), + tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), + tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), + tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, + FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx, + TensorStorage &shared_tensors, Params const &mainloop_params) { + + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); + 
static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), + SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto { + if constexpr (SwapAB) { + return thread_mma.partition_A(sAux); + } else { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto { + if constexpr (SwapAB) { + return thread_mma.make_fragment_A(tCsAux); + } else { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } else { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == + size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation0( + accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + GmmaFP8Accumulation accumulation1( + accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; + --k_tile_prologue) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation0.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + 
warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accumulation1()); + } else { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation0.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), accumulation1()); + } else { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to + /// ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation0.promote_residue_if_needed(); + accumulation1.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + 
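The FP8 specialization above differs from the half-precision mainloop chiefly in how it accumulates: each accumulator is wrapped in a GmmaFP8Accumulation and partial results are promoted into the main accumulator every mma_promotion_interval K-blocks (which can_implement requires to be a multiple of 4, since each mainloop iteration issues four MMA instructions), rather than being accumulated in low precision across the whole K loop. The scalar loop below sketches that promotion pattern only; the names are illustrative and not part of the GmmaFP8Accumulation API.

    #include <vector>

    // Scalar sketch of the promote_if_needed()/promote_residue_if_needed()
    // pattern: partial sums are flushed into a higher-precision running total
    // every promotion_interval steps instead of after every product.
    float accumulate_with_promotion(std::vector<float> const& partials,
                                    int promotion_interval) {
      float main_acc = 0.0f;  // running total kept at full precision
      float temp_acc = 0.0f;  // short-lived accumulator, reset at each promotion
      int pending = 0;
      for (float p : partials) {
        temp_acc += p;
        if (++pending == promotion_interval) {  // promote_if_needed()
          main_acc += temp_acc;
          temp_acc = 0.0f;
          pending = 0;
        }
      }
      main_acc += temp_acc;  // promote_residue_if_needed()
      return main_acc;
    }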
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp new file mode 100644 index 000000000..15faad26e --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. 
+ **/ +template +class GemmUniversalGated; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 8ac984faf..40f128b7a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -130,6 +130,15 @@ public: using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + template struct LayoutDetailsB= 90>::type> { diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp new file mode 100644 index 000000000..843529cde --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated< + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + cute::enable_if_t && + CollectiveMainloop_::isGated>> { +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or + cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = + CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = + CUTE_STATIC_V(size(TiledMma{})) + + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load + // threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Kernel level 
shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = + typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = + typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void *workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. + static Params to_underlying_arguments(Arguments const &args, + void *workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count( + args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST( + "to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t *workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void *scheduler_workspace = workspace_ptr; + workspace_offset += + TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void *epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size( + args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void *mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. + // Therefore it will be used in separate reduction scheme for streamk case, + // NumEpilogueSubTiles default value is 1, which means subtile will not be + // used, therefore separate reduction will not be enabled. 
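+    // (Workspace layout sketch: [ scheduler | pad to MinWorkspaceAlignment |
+    // epilogue ]; no region is carved out for the mainloop, which is why
+    // mainloop_workspace above stays nullptr.)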
+ constexpr uint32_t NumEpilogueSubTiles = + CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, + args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + + return {args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, epilogue_workspace), + hw_info, + scheduler, + workspace}; + } + + static bool can_implement(Arguments const &args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && + cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't " + "meet the requirements.\n"); + return implementable; + } + implementable &= + CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= + CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const &args) { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = + CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += + TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, + args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t *workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = + CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, + args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_offset += + TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, + stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size( + args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const ¶ms) { + // Given device SM count, set grid size s.t. 
we do not launch more thread + // blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = + params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, + TileShape{}, ClusterShape{}, + params.hw_info, args); + } + + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting " + "sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert( + size(TiledMma{}) == 256, + "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or " + "equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. 
If batch mode is not " + "needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the + * same tile */ + enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage &shared_storage = + *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Mainloop) { + mainloop_pipeline_params.role = + MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = + MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = + CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = + CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, + epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = + producer_warp_role == ProducerWarpRole::Mainloop ? 
0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, + params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via + // scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = + cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = + cutlass::make_producer_start_state(); + + auto cluster_wait_fn = []() { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } else { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Optionally append 1s until problem shape is rank-4 in case it is only + // rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread + // block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and + // compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, + shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors + // where: get<0>(load_inputs) is the tma tensor A after local tiling so that + // it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B + // after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = + collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert( + cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. 
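+    // The third tensor returned by load_init (gAux_xkl below) is presumably the
+    // auxiliary/gate B operand that the gated mainloop consumes alongside gB_nkl.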
+ Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and + // n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the + // starting K tile offset of the work. + auto work_k_tile_count = TileScheduler::get_work_k_tile_count( + work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = + TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator( + idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load( + params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, + block_rank_in_cluster, shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, + mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && + collective_epilogue.is_producer_load_needed()) { + while (work_tile_info.is_valid()) { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) { + load_order_barrier.wait(); + } + if (TileScheduler::compute_epilogue(work_tile_info, + params.scheduler)) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and + // n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, epi_load_pipe_producer_state, + problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, + epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + 
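+    // Consumer path (sketch): each work tile yields two accumulators, acc0 and
+    // acc1, which the loop below fuses element-wise as
+    //   acc0[i] = Activation{}(acc0[i] * scale_d0) * (scale_d1 * acc1[i]);
+    // (a GLU-style gate; the actual functor is whatever the mainloop's
+    // Activation alias names), and only the fused acc0 is written by the epilogue.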
+ else if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc(); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load + // is waiting for it + bool do_store_tail = false; + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and + // n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count = TileScheduler::get_work_k_tile_count( + work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C( + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C( + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + collective_mainloop.mma( + mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, + accumulators1, work_k_tile_count, mma_thread_idx, + shared_storage.tensors.mainloop, params.mainloop); + + // Make sure the math instructions are done and free buffers before + // entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = + canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0, + NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1, + NumMmaWarpGroups, consumer_warp_group_idx); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) { + accumulators0[i] = elt_op(accumulators0[i] * scale_d0) * + (scale_d1 * accumulators1[i]); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, + epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state); + } + } // Consumer Warp Groups End +#endif + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo + fetch_next_work(typename 
TileScheduler::WorkTileInfo &work_tile_info, + TileScheduler &scheduler) const { + // Check whether we should continue on with the current work unit. If this + // is the case, the work unit will have been updated in + // continue_current_work to reflect the new tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp new file mode 100644 index 000000000..e6cc7de5c --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,680 @@ +/*************************************************************************************************** + * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cute/util/debug.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated< + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + cute::enable_if_t && + CollectiveMainloop_::isGated>> { +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or + cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert( + !cute::is_same_v, + "Ping-pong kernel does not currently support stream-K scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock = + CUTE_STATIC_V(size(TiledMma{})) + + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load + // threads + 
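+  // OrderedSequenceBarrier<1, 2>: a single-stage ordered sequence shared by the
+  // two producer-load warps, so the epilogue-load warp begins only after the
+  // mainloop-load warp has arrived for its first tile (see the
+  // load_order_barrier arrive()/wait() calls in operator() below). +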
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Order Sequence barrier with two stages: one for Mainloop and one for + // Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = + cutlass::OrderedSequenceBarrier; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = + typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = + typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = + typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. 
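+  // Note: unlike the cooperative variant, this Params struct keeps no workspace
+  // pointer; the workspace handed to to_underlying_arguments is only carved into
+  // the tile-scheduler and epilogue sub-allocations.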
+ static Params to_underlying_arguments(Arguments const &args, + void *workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void)workspace; + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count( + args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST( + "to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t *workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void *scheduler_workspace = workspace_ptr; + workspace_offset += + TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void *epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size( + args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void *mainloop_workspace = nullptr; + + return {args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, epilogue_workspace), + hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, + args.scheduler, scheduler_workspace)}; + } + + static bool can_implement(Arguments const &args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && + cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't " + "meet the requirements.\n"); + return implementable; + } + implementable &= + CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= + CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const &args) { + size_t workspace_size = 0; + workspace_size += + TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, + args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t *workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + 
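+    // Initialize the sub-workspaces in the same order they were sized in
+    // get_workspace_size(): the tile-scheduler region first, then the epilogue
+    // region, each offset rounded up to MinWorkspaceAlignment. +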
status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, + args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += + TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, + stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size( + args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const ¶ms) { + // Given device SM count, set grid size s.t. we do not launch more thread + // blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = + params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, + TileShape{}, ClusterShape{}, + params.hw_info, args); + } + + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting " + "sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. 
If batch mode is not " + "needed, set L stride to Int<0>."); + + enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage &shared_storage = + *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Mainloop) { + mainloop_pipeline_params.role = + MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = + MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = + CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && + producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = + CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, + epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = + producer_warp_role == ProducerWarpRole::Mainloop ? 
0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, + params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = + canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = + NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier( + shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via + // scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = + cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = + cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = + cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&]() { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } else { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case it is only + // rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread + // block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // In a warp specialized kernel, collectives expose data movement and + // compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, + shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors + // where: get<0>(load_inputs) is the tma tensor A after local tiling so that + // it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B + // after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = + collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert( + cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. 
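+    // The k/c/d tile counts computed below also stagger the two math warp groups
+    // in the ping-pong schedule: Consumer1 is advanced one work tile ahead and its
+    // pipeline states are moved past the tiles Consumer0 will consume first.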
+ Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + if (warp_group_role == WarpGroupRole::Consumer1) { + // Advance 2nd Math WG to the next work tile for the startup + scheduler.advance_to_next_work(); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and + // n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + + collective_mainloop.load( + params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, + block_rank_in_cluster, shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, + mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && + collective_epilogue.is_producer_load_needed()) { + load_order_barrier.wait(); + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and + // n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, epi_load_pipe_producer_state, + problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + shared_storage.tensors.epilogue); + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, + epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role 
== WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc(); + + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and + // n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the accumulators for the (M,N) blk_shape + Tensor accumulators0 = partition_fragment_C( + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + Tensor accumulators1 = partition_fragment_C( + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + // Order two Math WG's MMA one after the other, helps hide Epilogue + math_wg_order_barrier.wait(); + + collective_mainloop.mma( + mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, + accumulators1, k_tile_count, warp_group_thread_idx, + shared_storage.tensors.mainloop, params.mainloop); + + // Cue for next Math WG's MMA to start + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before + // entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) { + accumulators0[i] = elt_op(accumulators0[i] * scale_d0) * + (scale_d1 * accumulators1[i]); + } + + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, + epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, warp_group_thread_idx, + shared_storage.tensors.epilogue); + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for + // multiple-consumer kernels we need to wait for all TMA stores to + // complete before issuing consumer order barrier arrives to ensure next + // math consumer doesn't overwrite smem of in-flight TMA stores of + // current consumer. 
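+      // store_tail() below waits out the in-flight TMA stores and returns the
+      // pipeline states it stopped at; those states are then advanced by one more
+      // tile's worth (c_tile_count / d_tile_count) to skip the tile the other
+      // math warp group is handling.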
+ auto [epi_load_pipe_consumer_state_next_, + epi_store_pipe_producer_state_next_] = + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, + // advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + // Get next work tile + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index ad6c7496e..bc395d04d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -77,6 +77,7 @@ public: }; //////////////////////////////////////////////////////////////////////////////// + /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) template < /// Layout type for A matrix operand @@ -125,6 +126,7 @@ public: }; //////////////////////////////////////////////////////////////////////////////// + /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage /// (stage>=3) template < @@ -148,7 +150,7 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator, - /// + /// Number of stages used in the multistage mainloop int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> @@ -179,6 +181,7 @@ public: }; //////////////////////////////////////////////////////////////////////////////// + /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage /// (stage>=3) template < @@ -234,6 +237,7 @@ public: #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////// + /// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage /// (stage>=3) template < @@ -346,6 +350,131 @@ struct DefaultMma; }; +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// 
Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Number of stages used in the multistage mainloop + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +}; + } // namespace threadblock } // namespace gemm } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 77af81005..5d2c31170 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -19,13 +19,11 @@ #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -197,6 +195,7 @@ public: }; //////////////////////////////////////////////////////////////////////////////// + /// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight template < /// Layout type for A matrix operand @@ -244,6 +243,9 @@ public: using ThreadblockMma = typename Mma::ThreadblockMma; }; +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -265,7 +267,7 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator, - /// + /// Number of stages used in the multistage mainloop int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> @@ -296,6 +298,7 @@ public: }; //////////////////////////////////////////////////////////////////////////////// + /// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight template < /// Layout type for A matrix operand @@ -318,11 +321,11 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator, - /// + /// Number of stages used in the multistage mainloop int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma { @@ -348,6 +351,131 @@ public: using ThreadblockMma = typename Mma::ThreadblockMma; }; +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage + 
+template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Number of stages used in the multistage mainloop + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +}; + } // namespace threadblock } // namespace gemm } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h new file mode 100644 index 000000000..6dd55b647 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class Wint2xMmaBase { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = + GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // using TensorRefZippedB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + // w uint8; local_scale uint8; + constexpr static int kZippedRowsPerStages = + Shape::kK / 4 + (Shape::kK + 127) / 128; + + // code_scale float; code_zp float; super_scale ElementB + constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + + sizeof_bits::value / 8; + + using ZippedShapeB = MatrixShape; + + using NopaddingShapeB = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for quanted B operand + AlignedBuffer operand_zipped_B; + + /// Buffer for unzip B operand + AlignedBuffer + operand_unzip_B; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + 
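The SharedStorage sizing above packs the quantized B tile as kK/4 rows of uint8 weights plus one uint8 row of 4-bit local scales per 128 rows of K (kZippedRowsPerStages), and reserves kColumnWiseParamsRows bytes per column for code_scale, code_zp (both float) and the super scale (one ElementB). A small host-side restatement of that arithmetic, assuming a 16-bit ElementB and a 64x128 K-by-N tile with 4 stages purely for illustration:

#include <cstdio>

// Mirrors kZippedRowsPerStages: packed weights (kK / 4 rows of uint8) plus
// local scales (one uint8 row per 128 rows of K).
constexpr int ZippedRowsPerStage(int tile_k) {
  return tile_k / 4 + (tile_k + 127) / 128;
}

// Mirrors kColumnWiseParamsRows: code_scale (float) + code_zp (float)
// + super_scale (ElementB), expressed in bytes per column.
constexpr int ColumnWiseParamRows(int element_b_bits) {
  return 2 * static_cast<int>(sizeof(float)) + element_b_bits / 8;
}

int main() {
  constexpr int kTileK = 64, kTileN = 128, kStages = 4, kElementBBits = 16;  // illustrative only
  constexpr int zipped_bytes_per_stage = ZippedRowsPerStage(kTileK) * kTileN;          // 17 * 128
  constexpr int column_wise_bytes      = ColumnWiseParamRows(kElementBBits) * kTileN;  // 10 * 128
  std::printf("zipped B smem: %d B/stage (%d B over all stages), column-wise params: %d B\n",
              zipped_bytes_per_stage, zipped_bytes_per_stage * kStages, column_wise_bytes);
  return 0;
}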
CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + CUTLASS_HOST_DEVICE + uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } + + CUTLASS_HOST_DEVICE + typename Operator::ElementB *operand_unzip_B_ptr() { + return operand_unzip_B.data(); + } + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + Wint2xMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h new file mode 100644 index 000000000..38fdcf9fe --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -0,0 +1,807 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/arch/memory_copy_sm80.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class Wint2xMmaMultistage : + public Wint2xMmaBase { +public: + ///< Base class + using Base = Wint2xMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + + 
using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. 
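The Detail struct above spreads one stage's worth of cp.async instructions evenly across the warp-level MMA iterations using a ceiling division, so every warp tile issues roughly the same share of global-to-shared copies. A trivial restatement of that split, with made-up counts only to show the rounding:

#include <cstdio>

// Mirrors Detail::kAccessesPerGroupA / kAccessesPerGroupB: ceil-divide the
// per-stage cp.async count by the number of warp-level MMA iterations.
constexpr int AccessesPerGroup(int copies_per_stage, int warp_gemm_iterations) {
  return (copies_per_stage + warp_gemm_iterations - 1) / warp_gemm_iterations;
}

int main() {
  // 8 copies over 4 warp tiles -> 2 per tile; 9 copies over 4 warp tiles -> 3
  // per tile, with the last group partially idle.
  std::printf("%d %d\n", AccessesPerGroup(8, 4), AccessesPerGroup(9, 4));
  return 0;
}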
+ static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; + }; + + private: + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + WarpTransformedFragmentA warp_transformed_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_[2]; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + uint8_t* column_wise_smem_ptr_B_; + + uint8_t* smem_zipped_ptr_B_; + int smem_zipped_bytes_per_stage_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + Wint2xMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + + column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); + + smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; + smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + 
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + smem_read_stage_idx_ = 0; + } + this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + template + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + TileDequanterB &tile_dequanter_B) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + //iterator_B.add_tile_offset({1, 0}); + tile_dequanter_B.AddTileOffset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + //smem_iterator_B_.add_tile_offset({1, 0}); + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + //smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + } + + template + CUTLASS_DEVICE + void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) { + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::copy_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::copy( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + } + __syncthreads(); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) { + iterator_A.set_iteration_index(0); + 
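advance_smem_read_stage() and advance_smem_write_stage() above treat the kStages shared-memory buffers as a circular queue: each call bumps a stage index and advances the corresponding iterator by one tile, and when the index reaches Base::kStages the iterator is rewound by kStages tiles and the index wraps to zero. A minimal host-side model of that bookkeeping (StageCursor is a hypothetical name, not a CUTLASS type):

#include <cassert>

// Hypothetical helper modelling the wrap-around logic of
// advance_smem_write_stage / advance_smem_read_stage.
struct StageCursor {
  int stages;      // Base::kStages
  int index = 0;   // smem_write_stage_idx_ / smem_read_stage_idx_
  int tile = 0;    // net tile offset applied to the iterator

  void advance() {
    ++tile;            // add_tile_offset(+1 tile)
    ++index;
    if (index == stages) {
      tile -= stages;  // wrap back to the start of the circular buffer
      index = 0;
    }
  }
};

int main() {
  StageCursor c{/*stages=*/3};
  for (int i = 0; i < 7; ++i) c.advance();
  // After any number of advances the cursor stays inside [0, stages).
  assert(c.index == 7 % 3 && c.tile >= 0 && c.tile < 3);
  return 0;
}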
this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + template + CUTLASS_DEVICE + void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (InitStage) { + cutlass::arch::copy_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } else { + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::copy_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::copy( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + } + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + __syncthreads(); + } + + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + template + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + TileDequanterB &tile_dequanter_B, + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + // Async copy zipped B to shared memory. + copy_tiles_and_advance_per_stage_A(iterator_A); + + // Async copy zipped B to shared memory. + tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, + column_wise_smem_ptr_B_, stage); + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Optionally clear the remaining stages of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint are zero. 
+ if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + typename IteratorA::AccessType zero_A; + + zero_A.clear(); + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + } + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + template + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand + int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining + int stage) + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k); + + // Load the next warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + // Unpack and dequant the first stage of B. + int unpack_stage = stage - Base::kStages + 2; + tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, + column_wise_smem_ptr_B_, unpack_stage); + + // Copy dequatized data to shared memory used by mma core. 
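prologue() above commits Base::kStages - 1 fenced cp.async stages before any math, and gmem_wait() then blocks until at least one of them has landed in shared memory; the template argument of cp_async_wait was lost in this diff, so the wait depth of kStages - 2 used below is an assumption read off the surrounding comment. A host-side model of the stage counters:

#include <cassert>

// Model of the cp.async pipeline depth used by prologue() / gmem_wait().
struct AsyncPipelineModel {
  int stages;          // Base::kStages
  int in_flight = 0;   // fenced-but-not-waited-for stages

  void commit_stage() { ++in_flight; }   // cp_async_fence()
  void wait_all_but(int n) {             // cp_async_wait<n>() (assumed depth)
    if (in_flight > n) in_flight = n;
  }
};

int main() {
  constexpr int kStages = 4;
  AsyncPipelineModel p{kStages};

  // Prologue: kStages - 1 committed stages before the first wait.
  for (int s = 0; s < kStages - 1; ++s) p.commit_stage();
  assert(p.in_flight == kStages - 1);

  // gmem_wait(): keep at most kStages - 2 stages uncommitted, i.e. at least
  // one stage is now resident in shared memory.
  p.wait_all_but(kStages - 2);
  assert(p.in_flight == kStages - 2);
  return 0;
}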
+ copy_tiles_and_advance_per_stage_B(iterator_B); + } + + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B_; + + // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary + if (warp_mma_k > 0) { + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // Except for the last warp-tile, all warp-tiles issue their share of + // global->shared fragment copies + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + + copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + + if (warp_mma_k == 0) { + tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, + column_wise_smem_ptr_B_, stage); + } + } + + // The second-to-last warp-tile also: + // - performs the last warp-tile's share of global->shared fragment copies + // - moves to the next global fetch stage + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Performs the last warp-tile's share of global->shared fragment copies + int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + + copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); + } + + // The last warp-tile also converts the shared memory fragments used by + // the first warp-tile of the next iteration, if necessary (so we can + // immediately start issuing MMA instructions at the top of the loop ) + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + } + } + } + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. 
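mac_loop_iter() above double-buffers the warp-level fragments: the fragment for warp tile k is consumed out of slot k % 2 while the shared-memory load for tile k + 1 fills slot (k + 1) % 2, so loads overlap the MMA issue. A toy model of that ping-pong indexing (the names here are illustrative, not CUTLASS types):

#include <cstdio>

// Model of the two-deep fragment ping-pong used by PipeState / mac_loop_iter().
template <typename Fragment>
struct FragmentPingPong {
  Fragment slots[2];
  Fragment &compute_slot(int warp_mma_k) { return slots[warp_mma_k % 2]; }
  Fragment &load_slot(int warp_mma_k) { return slots[(warp_mma_k + 1) % 2]; }
};

int main() {
  FragmentPingPong<int> frag_A{};
  const int kWarpGemmIterations = 4;
  for (int k = 0; k < kWarpGemmIterations; ++k) {
    frag_A.load_slot(k) = k + 1;  // "load" warp tile k + 1 while tile k is computed
    std::printf("iter %d: compute slot %d, load slot %d\n", k, k % 2, (k + 1) % 2);
  }
  return 0;
}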
+ template + CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, + TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory + { + PipeState pipe_state; + + // Unpack and dequant the first stage of B. + tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + // Copy dequatized data to shared memory used by mma core. + copy_tiles_and_advance_per_stage_B(iterator_B); + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + ++this->warp_tile_iterator_B_; + + // Transform, if necessary, the first warp-tile's shared memory fragments + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[0], + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_A_[0], + pipe_state.warp_loaded_frag_B_[0]); + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + int stage = Base::kStages - 1; + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + tile_dequanter_B, + gemm_k_iterations, + stage); + stage += 1; + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + /// Prepares the class for another prologue. + CUTLASS_DEVICE + void wind_down() + { + // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) + + // First, increment remaining warp tiles to get to the next full stage. (Ideally we would + // just decrement one tile, but not all iterators implement --() decrement.) 
+ #pragma unroll + for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); + this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + smem_read_stage_idx_++; + + // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) + static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; + if (smem_read_stage_idx_ > 1) + { + this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); + //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); + this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + } + smem_read_stage_idx_ = smem_write_stage_idx_; + } + + /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. + template + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< pre-load and dequantize B to shared memory + TileDequanterB tile_dequanter_B, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h new file mode 100644 index 000000000..cec6bcea0 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h @@ -0,0 +1,130 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
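The loop bounds in gemm_iters() above encode the pipeline drain: the prologue already decremented gemm_k_iterations once per prefetched stage, and the mainloop keeps running while gemm_k_iterations > -(Base::kStages - 1), decrementing once per iteration, so every K tile still gets exactly one mainloop trip. A small host-side check of that counting:

#include <cassert>

int main() {
  constexpr int kStages = 4;
  constexpr int total_k_tiles = 10;

  int gemm_k_iterations = total_k_tiles;
  // Prologue: one decrement per prefetched stage (kStages - 1 of them).
  for (int s = 0; s < kStages - 1; ++s) --gemm_k_iterations;

  int mainloop_trips = 0;
  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations) {
    ++mainloop_trips;  // one mac_loop_iter() per trip
  }

  // Including the drain iterations, each K tile is multiplied exactly once.
  assert(mainloop_trips == total_k_tiles);
  return 0;
}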
+ +#pragma once + +#include "cutlass/gemm_coord.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template +struct TileDequanter { + using WeightQuantTraits = WintQuantTraits; + using MmaElementT = typename WeightQuantTraits::MmaWeightType; + using QuantArguments = typename WeightQuantTraits::Arguments; + + using UnzipAndDequantFunctor = + UnzipAndDequantFunctor; + + static constexpr bool kUseSharedMemory = true; + + static constexpr int kRows = Rows; + static constexpr int kColumns = Columns; + static constexpr int kStages = Stages; + + MmaElementT *out_smem_ptr{nullptr}; + + char *pointer{nullptr}; + int64_t ldm{0}; + cutlass::MatrixCoord tb_offset; + cutlass::MatrixCoord extent; + + ScaleElementT *super_scale_ptr{nullptr}; + cutlass::MatrixCoord tb_offset_scale; + + QuantArguments quant_args; + + int64_t block_start_rows[kStages]; + bool need_preload{true}; + UnzipAndDequantFunctor unzip_functor; + + CUTLASS_DEVICE + TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, + const cutlass::MatrixCoord &extent, + const cutlass::MatrixCoord &tb_offset, + ScaleElementT *super_scale_ptr, + const cutlass::MatrixCoord &tb_offset_scale, + const QuantArguments &quant_args) + : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent), + tb_offset(tb_offset), super_scale_ptr(super_scale_ptr), + tb_offset_scale(tb_offset_scale), quant_args(quant_args) {} + + CUTLASS_DEVICE + MmaElementT *GetOutPtr() { return out_smem_ptr; } + + CUTLASS_DEVICE + void AddTileOffset(const cutlass::MatrixCoord &tile_offset) { + tb_offset.row() += tile_offset.row() * kRows; + tb_offset.column() += tile_offset.column() * kColumns; + tb_offset_scale.column() += tile_offset.column() * kColumns; + } + + CUTLASS_DEVICE + void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { + int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row()); + if (tb_offset.row() >= extent.row() || + tb_offset.column() >= extent.column()) { + return; + } + + block_start_rows[stage % kStages] = tb_offset.row(); + + using ZippedT = typename WeightQuantTraits::WeightType; + ZippedT *in_ptr = reinterpret_cast(pointer) + zipped_row * ldm + + tb_offset.column(); + ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column(); + + if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { + const uint8_t *local_scale_ptr = quant_args.local_scale_ptr + + (tb_offset.row() / 128) * ldm + + tb_offset_scale.column(); + const float *code_scale_ptr = + quant_args.code_scale_ptr + tb_offset_scale.column(); + const float *code_zp_ptr = + quant_args.code_zp_ptr + tb_offset_scale.column(); + + typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); + unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr, + scale_ptr, &args, ldm, need_preload); + need_preload = false; + } else { + // CUTLASS_TRACE_DEVICE("Not Supported!"); + } + } + + CUTLASS_DEVICE + void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { + int64_t block_start_row = block_start_rows[stage % kStages]; + if (block_start_row >= extent.row()) { + return; + } + + if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { + typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); + unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row); + } else { + // CUTLASS_TRACE_DEVICE("Not 
Supported!"); + } + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h new file mode 100644 index 000000000..9d49d5eb5 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h @@ -0,0 +1,447 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/arch/memory.h" +#include "cutlass/trace.h" +#include "cutlass_extensions/wint_type_traits.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template +using UnzipArray = cutlass::AlignedArray::value / 8)>; + +template +struct UnzipAndDequantFunctor { + __device__ void operator()(const T *in_ptr, const T *supper_scale_ptr, + T *out_ptr, const int64_t in_stride) {} +}; + +template +struct UnzipAndDequantFunctor { + using ZippedT = uint16_t; + using ScaleComputeT = float; + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kZippedGroupSize = 10; + static constexpr int32_t kNumPackedValues = 7; + + static constexpr int32_t kWeightMask = 0x7; + static constexpr int32_t kLocalScaleMask = 0x1FFF; + static constexpr int32_t kBZP = 4; + + __device__ inline T Compute(int32_t zipped_value, int32_t shift_bit, + ScaleComputeT scale) { + int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask; + int32_t value = shifted_value - kBZP; + + ScaleComputeT scaled_value = static_cast(value) * scale; + return static_cast(scaled_value); + } + + __device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr, + T *out_ptr, const int64_t in_stride) { + int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0}; + + int tid = threadIdx.x; + +#pragma unroll + for (int col = tid; col < TileColumns; col += NumThreads) { + ScaleComputeT super_scale = + static_cast(super_scale_ptr[col]); + +#pragma unroll + for (int group_id = 0; group_id < TileRows / 64; ++group_id) { + // the last row in group + int zipped_row_last = group_id * 10 + 9; + int zipped_offset_last = zipped_row_last * in_stride + col; + int32_t zipped_value_last = + static_cast(in_ptr[zipped_offset_last]); + + ScaleComputeT local_scale = + static_cast(zipped_value_last & kLocalScaleMask); + ScaleComputeT scale = local_scale * super_scale; + +#pragma unroll + for (int zipped_row_in_group = 0; zipped_row_in_group < 9; + ++zipped_row_in_group) { + int zipped_row = group_id * 10 + zipped_row_in_group; + int zipped_offset = zipped_row * in_stride + col; + int32_t zipped_value = static_cast(in_ptr[zipped_offset]); + + int row_in_group = group_id * 64 + zipped_row_in_group * 7; + +#pragma unroll + for (int shift_bit_id = 0; shift_bit_id < 7; ++shift_bit_id) { + int32_t shift_bit = shift_bits[shift_bit_id]; + T value = Compute(zipped_value, shift_bit, scale); + out_ptr[(row_in_group + shift_bit_id) * 
TileColumns + col] = value; + } + } + + int row_in_group_last = group_id * 64 + 63; + T value_last = Compute(zipped_value_last, shift_bits[0], scale); + out_ptr[row_in_group_last * TileColumns + col] = value_last; + } + } + __syncthreads(); + } +}; + +template +struct UnzipAndDequantFunctor { + using ZippedT = uint8_t; + using ScaleComputeT = float; + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kPackNum = 4; + static constexpr int32_t kWeightMask = 0x3F; + static constexpr int32_t kLocalScaleMask = 0xF; + static constexpr int32_t kBZP = 32; + + // weight [16, N] uint8_t + // local_scale [1, N] uint8_t + // code_scale [N] float + // code_zp [N] float + // super_scale [N] T + + // code_scale, code_zp and super_scale + static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns; + // zipped weights and local_scale + static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns; + + struct Arguments { + uint8_t *weight_ptr; + uint8_t *local_scale_ptr; + float *code_scale_ptr; + float *code_zp_ptr; + T *super_scale_ptr; + + __device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {} + + __device__ explicit Arguments(uint8_t *smem_ptr) { + SetZippedPtrs(smem_ptr); + SetColumnWisePtrs(smem_ptr + kZippedSmemBytes); + } + + __device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) { + SetZippedPtrs(zipped_smem_ptr); + SetColumnWisePtrs(column_wise_smem_ptr); + } + + __device__ void SetZippedPtrs(uint8_t *zipped_smem_ptr) { + weight_ptr = zipped_smem_ptr; + local_scale_ptr = zipped_smem_ptr + (TileRows / 4) * TileColumns; + } + + __device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) { + code_scale_ptr = reinterpret_cast(column_wise_smem_ptr); + code_zp_ptr = reinterpret_cast(column_wise_smem_ptr + sizeof(float) * TileColumns); + super_scale_ptr = reinterpret_cast(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns); + } + }; + + __device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr, + const float *g_code_scale_ptr, const float *g_code_zp_ptr, + const T *g_super_scale_ptr, + Arguments *args, const int64_t in_stride, bool need_preload) { + int tid = threadIdx.x; + +#pragma unroll + for (int col = tid; col < TileColumns; col += NumThreads) { + if (need_preload) { + if (g_super_scale_ptr) { + args->super_scale_ptr[col] = g_super_scale_ptr[col]; + } else { + args->super_scale_ptr[col] = static_cast(1); + } + + args->code_scale_ptr[col] = g_code_scale_ptr[col]; + args->code_zp_ptr[col] = g_code_zp_ptr[col]; + } + +#pragma unroll + for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) { + int local_scale_offset = ls_row_id * in_stride + col; + args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset]; + } + +#pragma unroll + for (int zipped_row = 0; zipped_row < TileRows / 4; ++zipped_row) { + int s_zipped_offset = zipped_row * TileColumns + col; + int g_zipped_offset = zipped_row * 4 * in_stride + col; + + args->weight_ptr[s_zipped_offset] = g_weight_ptr[g_zipped_offset]; + } + } + __syncthreads(); + } + + __device__ void LoadAsync(const uint8_t *g_weight_ptr, + const uint8_t *g_local_scale_ptr, + const float *g_code_scale_ptr, + const float *g_code_zp_ptr, + const T *g_super_scale_ptr, + Arguments *args, const int64_t in_stride, bool need_preload) { + int tid = threadIdx.x; + + constexpr int kBytesPerThread = 16; // 
16B per thread + + constexpr int weight_size = TileRows / 4 * TileColumns; + constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns; + constexpr int code_scale_size = sizeof(float) * TileColumns; + constexpr int code_zp_size = sizeof(float) * TileColumns; + constexpr int super_scale_size = sizeof(T) * TileColumns; + + constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size; + constexpr int total_tasks = total_size / kBytesPerThread; + + constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads); + + constexpr int weight_threads = weight_size * cur_num_threads / total_size; + constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size; + constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size; + constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size; + constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size; + + static_assert(TileColumns % weight_threads == 0, + "TileColumns must be divisible by weight_threads to ensure correct thread mapping."); + + static_assert(TileColumns % local_scale_threads == 0, + "TileColumns must be divisible by local_scale_threads to ensure correct thread mapping."); + + if (tid < weight_threads) { + constexpr int weight_per_thread_size = weight_size / weight_threads; + constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations; ++i) { + int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread); + int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns; + cutlass::arch::cp_async( + args->weight_ptr + z_offset, g_weight_ptr + g_offset, true); + } + } else if (tid < weight_threads + local_scale_threads) { + constexpr int start_thread_id = weight_threads; + constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads; + constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations; ++i) { + int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread; + int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns; + cutlass::arch::cp_async( + args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true); + } + } else if (need_preload) { + if (tid < weight_threads + local_scale_threads + code_scale_threads) { + constexpr int start_thread_id = weight_threads + local_scale_threads; + constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads; + constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations; ++i) { + int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float); + cutlass::arch::cp_async( + args->code_scale_ptr + offset, g_code_scale_ptr + offset, true); + } + } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) { + constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads; + constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads; + constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations; ++i) { + int offset = ((tid - 
start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float); + cutlass::arch::cp_async( + args->code_zp_ptr + offset, g_code_zp_ptr + offset, true); + } + } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) { + if (g_super_scale_ptr) { + constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads; + constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads; + constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kIterations; ++i) { + int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T); + cutlass::arch::cp_async( + args->super_scale_ptr + offset, g_super_scale_ptr + offset, true); + } + } + } + } + } + + __device__ void Compute(const Arguments &args, T *out_ptr, + const int64_t block_start_row) { + int32_t shift_bits[4] = {9, 6, 3, 0}; + + int tid = threadIdx.x; + +#pragma unroll + for (int col = tid; col < TileColumns; col += NumThreads) { + ScaleComputeT super_scale = + static_cast(args.super_scale_ptr[col]); + ScaleComputeT code_scale = + static_cast(args.code_scale_ptr[col]); + ScaleComputeT code_zp = static_cast(args.code_zp_ptr[col]); + +#pragma unroll + for (int group_id = 0; group_id < TileRows / 64; ++group_id) { + int local_scale_offset = (group_id / 2) * TileColumns + col; + int32_t local_scale = + static_cast(args.local_scale_ptr[local_scale_offset]); + + ScaleComputeT zipped_value[16]; + +#pragma unroll + for (int zipped_row = 0; zipped_row < 16; ++zipped_row) { + int zipped_offset = (group_id * 16 + zipped_row) * TileColumns + col; + zipped_value[zipped_row] = + static_cast(args.weight_ptr[zipped_offset]); + } + + int local_scale_shift = ((block_start_row / 64 + group_id + 1) & 1) * 4; + int32_t shifted_local_scale = + (local_scale >> local_scale_shift) & kLocalScaleMask; + ScaleComputeT scale = + static_cast(shifted_local_scale) * super_scale; + +#pragma unroll + for (int zipped_row = 0; zipped_row < 16; ++zipped_row) { + int32_t decode_value = + static_cast(floor(zipped_value[zipped_row] * code_scale + code_zp + + static_cast(0.5))); + + int row = group_id * 64 + zipped_row * 4; + +#pragma unroll + for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { + int32_t shift_bit = shift_bits[shift_bit_id]; + int32_t shifted_value = (decode_value >> shift_bit) & kWeightMask; + + ScaleComputeT value = + static_cast(shifted_value - kBZP); + out_ptr[(row + shift_bit_id) * TileColumns + col] = + static_cast(scale * value); + } + } + } + } + __syncthreads(); + } + + __device__ void ComputeVectorized(const Arguments &args, T *out_ptr, + const int64_t block_start_row) { + constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads); + constexpr int N = (kNumWeightsPerThread >= 32) ? 
4 : 2; + constexpr int RowStride = NumThreads * N / TileColumns; + constexpr int kNumIters = kNumWeightsPerThread / N; + + static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns."); + + constexpr ScaleComputeT decode_value_zp = static_cast(0.5); + + int tid = threadIdx.x; + int begin_col_id = (tid * N) % TileColumns; + int begin_row_id = (tid * N) / TileColumns; + + static_assert(TileRows <= 128, "TileRows is expected to no more than 128."); + + UnzipArray local_scales = + *reinterpret_cast *>(args.local_scale_ptr + begin_col_id); + + UnzipArray zipped_values[2]; + int zipped_offset = begin_row_id * TileColumns + begin_col_id; + zipped_values[0] = + *reinterpret_cast *>(args.weight_ptr + zipped_offset); + + UnzipArray super_scales = + *reinterpret_cast *>(args.super_scale_ptr + begin_col_id); + UnzipArray code_scales = + *reinterpret_cast *>(args.code_scale_ptr + begin_col_id); + UnzipArray code_zps = + *reinterpret_cast *>(args.code_zp_ptr + begin_col_id); + + // special for TileRows = 64 + int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4; + UnzipArray scales; + +#pragma unroll + for (int i = 0; i < N; ++i) { + int32_t shifted_local_scale = + (static_cast(local_scales[i]) >> local_scale_shift) & kLocalScaleMask; + scales[i] = + static_cast(shifted_local_scale) * static_cast(super_scales[i]); + } + +#pragma unroll + for (int iter_id = 0; iter_id < kNumIters; ++iter_id) { + int zipped_row = begin_row_id + iter_id * RowStride; + int row = zipped_row * 4; + + if (iter_id < kNumIters - 1) { + int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id; + zipped_values[(iter_id + 1) & 1] = + *reinterpret_cast *>(args.weight_ptr + zipped_offset); + } + + UnzipArray outs[4]; + +#pragma unroll + for (int i = 0; i < N; ++i) { + int32_t decode_value = + static_cast(floor(static_cast(zipped_values[iter_id & 1][i]) * code_scales[i] + + code_zps[i] + decode_value_zp)); + + ScaleComputeT value_3 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_2 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_1 = static_cast((decode_value & kWeightMask) - kBZP); + decode_value >>= 3; + ScaleComputeT value_0 = static_cast((decode_value & kWeightMask) - kBZP); + outs[0][i] = static_cast(scales[i] * value_0); + outs[1][i] = static_cast(scales[i] * value_1); + outs[2][i] = static_cast(scales[i] * value_2); + outs[3][i] = static_cast(scales[i] * value_3); + } + +#pragma unroll + for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) { + UnzipArray *tmp_out_ptr = reinterpret_cast *>( + out_ptr + (row + shift_bit_id) * TileColumns + begin_col_id); + *tmp_out_ptr = outs[shift_bit_id]; + } + } + __syncthreads(); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h new file mode 100644 index 000000000..9e1c6c463 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h @@ -0,0 +1,140 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
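// Editor's note: a standalone, host-side sketch of the per-column decode that the
// Compute/ComputeVectorized paths above perform on the GPU. It is an illustration
// under assumptions, not the kernel itself: the stored code word is taken as
// uint16_t here, and kWeightMask/kBZP/kLocalScaleMask (defined elsewhere in
// wint2x_unzip.h, not visible in this hunk) are assumed to be a 3-bit field mask
// (0x7), a small zero-point (4 here) and a 4-bit mask (0xF).
#include <array>
#include <cmath>
#include <cstdint>

// Decodes one packed code word into the four consecutive output rows it carries,
// mirroring the shift_bits = {9, 6, 3, 0} loop above.
std::array<float, 4> DecodeWint2CodeWord(uint16_t zipped_value,
                                         float code_scale, float code_zp,
                                         uint8_t packed_local_scale,
                                         int local_scale_shift,      // 0 or 4
                                         float super_scale,
                                         int32_t weight_mask = 0x7,  // assumed kWeightMask
                                         int32_t bzp = 4) {          // assumed kBZP
  // The per-column affine code (code_scale, code_zp) maps the stored value back
  // to an integer that holds four 3-bit fields.
  const int32_t decode_value = static_cast<int32_t>(
      std::floor(static_cast<float>(zipped_value) * code_scale + code_zp + 0.5f));

  // A 4-bit local scale (two 64-row groups share one byte, selected by the shift)
  // times the per-column super scale gives the dequantization scale for the group.
  const int32_t local_scale = (packed_local_scale >> local_scale_shift) & 0xF;
  const float scale = static_cast<float>(local_scale) * super_scale;

  std::array<float, 4> out{};
  for (int i = 0; i < 4; ++i) {
    const int shift = 9 - 3 * i;                           // 9, 6, 3, 0
    const int32_t q = (decode_value >> shift) & weight_mask;
    out[i] = scale * static_cast<float>(q - bzp);          // subtract zero-point, scale
  }
  return out;
}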
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +enum WintQuantMethod { + kNone = 0, + kWeightOnlyInt8 = 1, + kWeightOnlyInt4 = 2, + kWeightOnlyInt25 = 3, + kWeightOnlyInt2 = 4 +}; + +// Convert CUDA data type to cutlass data type +template struct CutlassDataType { + using Type = T; +}; + +template <> struct CutlassDataType { + using Type = cutlass::half_t; +}; + +template <> struct CutlassDataType<__nv_bfloat16> { + using Type = cutlass::bfloat16_t; +}; + +template struct WintQuantTraits; + +template +struct WintQuantTraits { + using WeightType = ElementT; + using MmaKernelType = typename CutlassDataType::Type; + using MmaWeightType = typename CutlassDataType::Type; + + static constexpr WintQuantMethod kQuantMethod = WintQuantMethod::kNone; + + struct Arguments {}; + + CUTLASS_DEVICE + static int64_t CaclPackedDim(int64_t dim) { return dim; } +}; + +template +struct WintQuantTraits { + using WeightType = uint8_t; + using MmaKernelType = uint8_t; + using MmaWeightType = uint8_t; + + static constexpr WintQuantMethod kQuantMethod = + WintQuantMethod::kWeightOnlyInt8; + + struct Arguments {}; + + CUTLASS_DEVICE + static int64_t CaclPackedDim(int64_t dim) { return dim; } +}; + +template +struct WintQuantTraits { + using WeightType = cutlass::uint4b_t; + using MmaKernelType = cutlass::uint4b_t; + using MmaWeightType = cutlass::uint4b_t; + + static constexpr WintQuantMethod kQuantMethod = + WintQuantMethod::kWeightOnlyInt4; + + struct Arguments {}; + + CUTLASS_DEVICE + static int64_t CaclPackedDim(int64_t dim) { return dim; } +}; + +template +struct WintQuantTraits { + using WeightType = uint16_t; + using MmaKernelType = typename CutlassDataType::Type; + using MmaWeightType = typename CutlassDataType::Type; + + static constexpr WintQuantMethod kQuantMethod = + WintQuantMethod::kWeightOnlyInt25; + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kNumPackedValues = 7; + static constexpr int32_t kPackedSize = 10; + + struct Arguments {}; + + CUTLASS_DEVICE + static int64_t CaclPackedDim(int64_t dim) { + return dim * kPackedSize / kGroupSize; + } +}; + +template +struct WintQuantTraits { + using WeightType = uint8_t; + using MmaKernelType = cutlass::uint2b_t; + using MmaWeightType = typename CutlassDataType::Type; + + static constexpr WintQuantMethod kQuantMethod = + WintQuantMethod::kWeightOnlyInt2; + + static constexpr int32_t kGroupSize = 64; + static constexpr int32_t kNumPackedValues = 4; + static constexpr int32_t kPackedSize = 16; + + struct Arguments { + const uint8_t *local_scale_ptr; // quanted 4-bits + const float *code_scale_ptr; + const float *code_zp_ptr; + }; + + CUTLASS_DEVICE + static int64_t CaclPackedDim(int64_t dim) { + return dim * kPackedSize / kGroupSize; + } +}; + +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h index 695a6d3db..3ac548c62 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h +++ 
b/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h @@ -16,106 +16,127 @@ #include #include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" #include "cutlass/half.h" #include "helper.h" #include "paddle/extension.h" -template -class CutlassDtypeTraits; +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + PD_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } -template <> -class CutlassDtypeTraits { - public: - typedef float DataType; - typedef float data_t; +/** + * A wrapper for a kernel that is used to guard against compilation on + * architectures that will never use the kernel. The purpose of this is to + * reduce the size of the compiled binary. + * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef + * into code that will be executed on the device where it is defined. + */ +template struct enable_sm90_or_later : Kernel { + template CUTLASS_DEVICE void operator()(Args &&...args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } }; -template <> -class CutlassDtypeTraits { - public: - typedef cutlass::half_t DataType; - typedef paddle::float16 data_t; +template class CutlassDtypeTraits; + +template <> class CutlassDtypeTraits { +public: + typedef float DataType; + typedef float data_t; }; -template <> -class CutlassDtypeTraits { - public: - typedef cutlass::bfloat16_t DataType; - typedef paddle::bfloat16 data_t; +template <> class CutlassDtypeTraits { +public: + typedef cutlass::half_t DataType; + typedef paddle::float16 data_t; +}; + +template <> class CutlassDtypeTraits { +public: + typedef cutlass::bfloat16_t DataType; + typedef paddle::bfloat16 data_t; }; class CutlassGemmConfigMannager { - public: - static CutlassGemmConfigMannager& getInstance() { - static CutlassGemmConfigMannager instance; - return instance; - } +public: + static CutlassGemmConfigMannager &getInstance() { + static CutlassGemmConfigMannager instance; + return instance; + } - CutlassGemmConfigMannager(const CutlassGemmConfigMannager&) = delete; - CutlassGemmConfigMannager& operator=(const CutlassGemmConfigMannager&) = - delete; + CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete; + CutlassGemmConfigMannager & + operator=(const CutlassGemmConfigMannager &) = delete; - void up_date_configs(const nlohmann::json& j) { - std::lock_guard lock(mutex_); - for (auto it = j.begin(); it != j.end(); ++it) { - json_[it.key()] = it.value(); - } + void up_date_configs(const nlohmann::json &j) { + std::lock_guard lock(mutex_); + for (auto it = j.begin(); it != j.end(); ++it) { + json_[it.key()] = it.value(); } + } - nlohmann::json* get_gemm_best_configs(const std::string& config_file_path) { - if (!load_initialized_) { - std::ifstream file(config_file_path); - if (!file.good()) { - throw std::runtime_error( - "cutlass gemm_best_config can not be found, please set " - "gemm_best_config'path as " - "FLAGS_use_cutlass_device_best_config_path, or unset " - "FLAGS_use_cutlass_device_best_config_path to tune " - "gemm_best_config"); - } - json_ = readJsonFromFile(config_file_path); - load_initialized_ = true; - save_initialized_ = false; - } - return &json_; + nlohmann::json *get_gemm_best_configs(const std::string &config_file_path) { + if (!load_initialized_) { + std::ifstream file(config_file_path); + if (!file.good()) { + throw std::runtime_error( + "cutlass gemm_best_config can not be found, please set " + 
"gemm_best_config'path as " + "FLAGS_use_cutlass_device_best_config_path, or unset " + "FLAGS_use_cutlass_device_best_config_path to tune " + "gemm_best_config"); + } + json_ = readJsonFromFile(config_file_path); + load_initialized_ = true; + save_initialized_ = false; } + return &json_; + } - private: - void save_gemm_best_configs_(const std::string& config_file_path) { - std::ifstream file(config_file_path); - if (!file.good()) { - std::ofstream new_file(config_file_path); - new_file << json_.dump(4); - new_file.close(); - } else { - nlohmann::json old_json = readJsonFromFile(config_file_path); - for (auto it = json_.begin(); it != json_.end(); ++it) { - old_json[it.key()] = it.value(); - } - json_ = old_json; - std::ofstream new_file(config_file_path, - std::ios::out | std::ios::trunc); - new_file << json_.dump(4); - new_file.close(); - file.close(); - } - return; +private: + void save_gemm_best_configs_(const std::string &config_file_path) { + std::ifstream file(config_file_path); + if (!file.good()) { + std::ofstream new_file(config_file_path); + new_file << json_.dump(4); + new_file.close(); + } else { + nlohmann::json old_json = readJsonFromFile(config_file_path); + for (auto it = json_.begin(); it != json_.end(); ++it) { + old_json[it.key()] = it.value(); + } + json_ = old_json; + std::ofstream new_file(config_file_path, std::ios::out | std::ios::trunc); + new_file << json_.dump(4); + new_file.close(); + file.close(); } + return; + } - CutlassGemmConfigMannager() - : json_(nullptr), load_initialized_(false), save_initialized_(true) {} - ~CutlassGemmConfigMannager() { - std::lock_guard lock(mutex_); - if (save_initialized_) { - std::string config_file_path = "fp8_fuse_gemm_config.json"; - save_gemm_best_configs_(config_file_path); - } - save_initialized_ = true; - load_initialized_ = false; - json_.clear(); + CutlassGemmConfigMannager() + : json_(nullptr), load_initialized_(false), save_initialized_(true) {} + ~CutlassGemmConfigMannager() { + std::lock_guard lock(mutex_); + if (save_initialized_) { + std::string config_file_path = "fp8_fuse_gemm_config.json"; + save_gemm_best_configs_(config_file_path); } - mutable std::mutex mutex_; - nlohmann::json json_; - bool load_initialized_; - bool save_initialized_; + save_initialized_ = true; + load_initialized_ = false; + json_.clear(); + } + mutable std::mutex mutex_; + nlohmann::json json_; + bool load_initialized_; + bool save_initialized_; }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h index f6e3fc963..73749020a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_dual_gemm_scale_bias_act.h @@ -15,8 +15,8 @@ #pragma once #include "fp8_common.h" -#include "fuse_dual_gemm_swiglu_template.h" +#include "fuse_dual_gemm_act_template_3x.h" #include "fuse_dual_gemm_geglu_template.h" +#include "fuse_dual_gemm_swiglu_template.h" -bool fp8_fp8_dual_gemm_scale_bias_act( - DualGemmEpilogueAllParams params); +bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params); diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h index bd7ca2765..9d8bd74c9 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h +++ 
b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fp8_fp8_gemm_scale_bias_act.h @@ -15,12 +15,13 @@ #pragma once #include "fp8_common.h" +#include "fuse_gemm_gelu_template.h" #include "fuse_gemm_noact_template.h" #include "fuse_gemm_relu_template.h" -#include "fuse_gemm_gelu_template.h" #include "fuse_block_gemm_act_template_3x.h" +#include "fuse_gemm_act_template_3x.h" bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params); -bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params); \ No newline at end of file +bool fp8_fp8_block_gemm_scale_bias_act(GemmEpilogueAllParams params); diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h new file mode 100644 index 000000000..943921e14 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h @@ -0,0 +1,173 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/float8.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "fp8_common.h" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp" +#include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp" + +template class Activation = + cutlass::epilogue::thread::SiLu, + bool SwapAB = true> +bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { + using namespace cute; + using ElementA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, cutlass::float_e5m2_t>; + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / + cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementA; // Element type for B matrix operand + using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / + cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + using ElementC = ElementA; // Element type for C matrix operands + + using LayoutC = cute::conditional_t; + static constexpr int AlignmentC = + 128 / + cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C matrices + // in units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = ElementA; // Element 
type for output matrix operands + // using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output + // matrix operands + using LayoutOutput = cute::conditional_t; + static constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using KernelSchedule = MainloopScheduleType; + using EpilogueSchedule = EpilogueScheduleType; + using TileScheduler = TileSchedulerType; + + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using FusionOperation = + cutlass::epilogue::fusion::ScaledAcc; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, + ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilderGated< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, + LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule, Activation, SwapAB>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + int arg_m = params.M; + int arg_n = params.N; + ElementA const *ptr_A = reinterpret_cast(params.A); + ElementB const *ptr_B0 = reinterpret_cast(params.B0); + ElementB const *ptr_B1 = reinterpret_cast(params.B1); + if constexpr (SwapAB) { + arg_m = params.N; + arg_n = params.M; + ptr_A = reinterpret_cast(params.B0); + ptr_B0 = reinterpret_cast(params.A); + } + StrideA stride_A = cutlass::make_cute_packed_stride( + StrideA{}, cute::make_shape(arg_m, params.K, params.batch_count)); + StrideB stride_B = cutlass::make_cute_packed_stride( + StrideB{}, cute::make_shape(arg_n, params.K, params.batch_count)); + StrideC stride_C; + StrideD stride_D = cutlass::make_cute_packed_stride( + StrideD{}, cute::make_shape(arg_m, arg_n, params.batch_count)); + + typename Gemm::Arguments arguments = { + cutlass::gemm::GemmUniversalMode::kGemm, + {arg_m, arg_n, params.K, params.batch_count}, + {ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1}, + {{}, // epilogue.thread + nullptr, + stride_C, + reinterpret_cast(params.D), + stride_D}}; + arguments.epilogue.thread.alpha = params.scale_out; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + 
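// Editor's note on the tail of dispatch_dual_gemm_act_sm90 (the hunk lines that
// follow): after can_implement succeeds, the dispatcher queries the kernel's
// workspace requirement, allocates it through the Paddle allocator bound to
// params.place, and launches the GEMM on params.stream. Both failure paths print
// to std::cerr and return false, so the boolean return value simply reports
// whether this template configuration ran.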
size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator *allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + // + // Run the GEMM + // + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + return true; +} \ No newline at end of file diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h new file mode 100644 index 000000000..819463175 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h @@ -0,0 +1,151 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "fp8_common.h" + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/util/packed_stride.hpp" + +template < + typename InputType, + typename OutType, + bool hasbias, + template typename Activation, + typename TileShape, + typename ClusterShape, + typename KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized, + typename SM = cutlass::arch::Sm90> +bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { + using namespace cute; + using ElementA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, cutlass::float_e5m2_t>; + using ElementB = ElementA; + using ElementD = + typename std::conditional_t, + cutlass::bfloat16_t, cutlass::half_t>; + using ElementC = std::conditional_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + // 16B alignment lets us use TMA + static constexpr int AlignmentA = 16 / sizeof(ElementA); + static constexpr int AlignmentB = 16 / sizeof(ElementB); + static constexpr int AlignmentC = hasbias ? 
16 / sizeof(ElementC) : 8; + static constexpr int AlignmentD = 16 / sizeof(ElementD); + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + + using FusionOperation = + cutlass::epilogue::fusion::LinCombEltAct; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, + AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + // + // Data members + // + + /// Initialization + StrideA stride_A{params.lda, cute::Int<1>{}, params.M * params.lda}; + StrideB stride_B{params.ldb, cute::Int<1>{}, params.N * params.ldb}; + StrideC stride_C{0, cute::Int<1>{}, 0}; + StrideD stride_D{params.ldd, cute::Int<1>{}, params.ldd * params.M}; + + auto a_ptr = reinterpret_cast(const_cast(params.A)); + auto b_ptr = reinterpret_cast(const_cast(params.B)); + auto c_ptr = reinterpret_cast(const_cast(params.bias)); + auto d_ptr = reinterpret_cast(params.D); + + ProblemShapeType problem_size = + ProblemShapeType{params.M, params.N, params.K, params.batch_count}; + + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {a_ptr, stride_A, b_ptr, stride_B}, + {{params.scale}, // epilogue.thread + c_ptr, + stride_C, + d_ptr, + stride_D}}; + if constexpr (hasbias) { + arguments.epilogue.thread.beta = 1.0; + } + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cout << "Gemm::can_implement() failed. " + << cutlassGetStatusString(status) << std::endl; + return false; + } + size_t workspace_size = Gemm::get_workspace_size(arguments); + phi::Allocator *allocator = paddle::GetAllocator(params.place); + auto workspace = allocator->Allocate(workspace_size); + + status = gemm_op(arguments, workspace->ptr(), params.stream); + if (status != cutlass::Status::kSuccess) { + std::cout << "Gemm::run() failed." 
<< cutlassGetStatusString(status) + << std::endl; + return false; + } + return true; +} \ No newline at end of file diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index c5d0a481d..356f30596 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -43,7 +43,9 @@ #include "cutlass/trace.h" #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -156,9 +158,6 @@ struct MoeFCGemm { using LayoutC = typename MapArguments::LayoutC; using ElementScale = ElementC; - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - // Type definitions about the mainloop. using Operator = typename Mma::Operator; using OperatorClass = typename Mma::Operator::OperatorClass; @@ -209,6 +208,13 @@ struct MoeFCGemm { int64_t gemm_n; int64_t gemm_k; + WintQuantMethod quant_method; + + // Extra arguments for wint2.0 + uint8_t* local_scale; + float* code_scale; + float* code_zp; + // Only used by device-level operator GemmCoord* host_problem_sizes; @@ -230,6 +236,10 @@ struct MoeFCGemm { total_rows(-1), gemm_n(0), gemm_k(0), + quant_method(WintQuantMethod::kNone), + local_scale(nullptr), + code_scale(nullptr), + code_zp(nullptr), host_problem_sizes(nullptr) {} /// Ctor @@ -246,6 +256,10 @@ struct MoeFCGemm { int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + WintQuantMethod quant_method, + const uint8_t* local_scale, + const float* code_scale, + const float* code_zp, GemmCoord* host_problem_sizes = nullptr) : problem_count(problem_count), threadblock_count(threadblock_count), @@ -259,8 +273,12 @@ struct MoeFCGemm { total_rows(total_rows), gemm_n(gemm_n), gemm_k(gemm_k), + quant_method(quant_method), + local_scale(const_cast(local_scale)), + code_scale(const_cast(code_scale)), + code_zp(const_cast(code_zp)), host_problem_sizes(nullptr) { - if (platform::is_same::value || + if (quant_method != WintQuantMethod::kNone || platform::is_same::value || platform::is_same::value) { assert(weight_scales); } @@ -284,6 +302,8 @@ struct MoeFCGemm { ElementC* ptr_C; ElementC* ptr_D; + WintQuantMethod quant_method; + // // Methods // @@ -294,7 +314,8 @@ struct MoeFCGemm { ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), - ptr_D(nullptr) {} + ptr_D(nullptr), + quant_method(WintQuantMethod::kNone) {} CUTLASS_HOST_DEVICE Params(Arguments const& args, @@ -313,7 +334,8 @@ struct MoeFCGemm { ptr_B(args.ptr_B), weight_scales(args.weight_scales), ptr_C(args.ptr_C), - ptr_D(args.ptr_D) {} + ptr_D(args.ptr_D), + quant_method(args.quant_method) {} CUTLASS_HOST_DEVICE void update(Arguments const& args, @@ -334,6 +356,7 @@ struct MoeFCGemm { weight_scales = args.weight_scales; ptr_C = args.ptr_C; ptr_D = args.ptr_D; + quant_method = args.quant_method; } }; @@ -358,7 +381,7 @@ struct MoeFCGemm { } static Status can_implement(Arguments const& args) { - if (platform::is_same::value || + if (args.quant_method != WintQuantMethod::kNone || platform::is_same::value || platform::is_same::value) 
{ if (args.weight_scales == nullptr) { CUTLASS_TRACE_HOST( @@ -394,6 +417,7 @@ struct MoeFCGemm { template struct KernelRunner { + CUTLASS_DEVICE static void run_kernel(Params const& params, SharedStorage& shared_storage) { // NOLINT @@ -401,12 +425,14 @@ struct MoeFCGemm { // These types shadow the type-level definitions and support the ability // to implement a 'transposed' GEMM that computes the transposed problems. // + using ElementA = typename Mma::IteratorA::Element; using LayoutA = typename Mma::IteratorA::Layout; using ElementB = typename Mma::IteratorB::Element; using LayoutB = typename Mma::IteratorB::Layout; using ElementC = typename Epilogue::OutputTileIterator::Element; using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; static_assert( @@ -435,6 +461,7 @@ struct MoeFCGemm { GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + // threadblock_offset of C cutlass::gemm::GemmCoord threadblock_offset( int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT @@ -450,6 +477,7 @@ struct MoeFCGemm { rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count); } + // begin address offset for A for current tile ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; typename LayoutA::LongIndex ldm_A = gemm_k; @@ -463,14 +491,17 @@ struct MoeFCGemm { : gemm_k * kInterleave; // Compute initial location in logical coordinates + // the begin threadblock_offset of A, which holds the same row id with C cutlass::MatrixCoord tb_offset_A{ threadblock_offset.m(), 0, }; + // the begin threadblock_offset of B, which holds the same column id with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; // Compute position within threadblock @@ -610,6 +641,381 @@ struct MoeFCGemm { ///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Wint2xMoeFCGemm : public MoeFCGemm { + public: + using Base = MoeFCGemm; + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = typename Base::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and + // complex conjugate operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + // Type definitions about the mainloop. 
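// Editor's note: the aliases that follow mirror the base MoeFCGemm typedefs
// (mainloop operator, threadblock/warp/instruction shapes, arch tag, stage count
// and alignments), so Wint2xMoeFCGemm reuses the base problem visitor, shared
// storage and epilogue plumbing and overrides only the parameter struct and the
// B-operand (wint2) handling inside its KernelRunner.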
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = typename Base::ProblemVisitor; + using Arguments = typename Base::Arguments; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params : Base::Params { + // Extra arguments for wint2.0 + uint8_t* local_scale; + float* code_scale; + float* code_zp; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : Base::Params(), local_scale(nullptr), code_scale(nullptr), code_zp(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, + void* workspace = nullptr, + int tile_count = 0) // NOLINT + : Base::Params(args, workspace, tile_count), + local_scale(args.local_scale), + code_scale(args.code_scale), + code_zp(args.code_zp) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, + void* workspace = nullptr, + int tile_count = 0) { + Base::update(args, workspace, tile_count); + + local_scale = args.local_scale; + code_scale = args.code_scale; + code_zp = args.code_zp; + } + }; + + /// Shared memory storage structure + using SharedStorage = typename Base::SharedStorage; + + public: + + // + // Methods + // + + CUTLASS_DEVICE + Wint2xMoeFCGemm() {} + + static Status can_implement(Arguments const& args) { + if (args.quant_method != WintQuantMethod::kWeightOnlyInt2) { + CUTLASS_TRACE_HOST( + "Wint2xMoeFCGemm::can_implement() - only support weight_only_int2!"); + return Status::kInvalid; + } else if (args.weight_scales == nullptr || args.local_scale == nullptr) { + CUTLASS_TRACE_HOST( + "Wint2xMoeFCGemm::can_implement() - weight_scales and local_scale is expected to be not nullptr!"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + // The dummy template parameter is not used and exists so that we can compile + // this code using a standard earlier than C++17. 
Prior to C++17, fully + // specialized templates HAD to exists in a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, + SharedStorage& shared_storage) { // NOLINT + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + template + struct KernelRunner { + using WeightQuantTraits = WintQuantTraits; + using QuantArguments = typename WeightQuantTraits::Arguments; + + CUTLASS_DEVICE + static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) { + QuantArguments quant_args; + if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { + quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128; + quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n; + quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n; + } + return quant_args; + } + + CUTLASS_DEVICE + static void run_kernel(Params const& params, + SharedStorage& shared_storage) { // NOLINT + // + // These types shadow the type-level definitions and support the ability + // to implement a 'transposed' GEMM that computes the transposed problems. + // + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using QuantElementB = typename WeightQuantTraits::WeightType; + using MmaElementB = typename WeightQuantTraits::MmaWeightType; + + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert( + platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // LayoutB should be RowMajor + using TileDequanterB = cutlass::gemm::threadblock::TileDequanter; + + // + // Problem visitor. + // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + // wint2.5 and wint2.0 is quantized and packed along k dimension with group_size 64. + const int64_t quant_gemm_k = WintQuantTraits::CaclPackedDim(gemm_k); + int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // threadblock_offset of C + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT + 0); + + // begin address offset for weight_scale. + ElementScale* weight_scale_ptr = + params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr; + // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Load element pointers. 
Exchange pointers and strides if working on + // the transpose + int64_t rows_to_jump = 0; + + if (params.problem_visitor.total_rows < 0) { + rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + } else { + rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count); + } + + // begin address offset for A for current tile + ElementA* ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + // Compute initial location in logical coordinates + // the begin threadblock_offset of A, which holds the same row id with C + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + // begin address offset for B for current problem_idx, totally num_experts problems + char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT + problem_idx * bytes_per_expert_matrix; // NOLINT + + typename LayoutB::LongIndex ldm_B = + platform::is_same::value + ? gemm_n + : gemm_k * kInterleave; + typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; + + // the begin threadblock_offset of B, which holds the same column id with C + cutlass::MatrixCoord tb_offset_B{0, + threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; + cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; + + MmaElementB* smem_unzip_B_ptr = nullptr; + if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { + smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); + } + QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n); + TileDequanterB tile_dequanter_B(smem_unzip_B_ptr, + byte_ptr_B, + ldm_B, + extent_B, + tb_offset_B, + weight_scale_ptr, + tb_offset_scale, + quant_args); + MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr(); + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B), + ptr_B, + TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B, + thread_idx, + TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the + // previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + tile_dequanter_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = + params.ptr_C ? 
reinterpret_cast(params.ptr_C) + problem_idx * gemm_n : nullptr; + ElementC* ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size.mn(), + thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue( + shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, + SharedStorage& shared_storage) { // NOLINT +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910) + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h index b7a0ce77d..c850e77bc 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h @@ -15,16 +15,22 @@ */ #pragma once + #include #include + #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h" +#include "cutlass_extensions/wint_type_traits.h" namespace phi { template + typename WeightQuantTraits /* The quant traits for the MoE weights */> class MoeGemmRunner { public: + using WeightType = typename WeightQuantTraits::WeightType; + using Arguments = typename WeightQuantTraits::Arguments; + MoeGemmRunner(); void moe_gemm_bias_act(const T* A, @@ -38,6 +44,7 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + const Arguments& quant_args_B, std::string activation_type, cudaStream_t stream); @@ -51,6 +58,7 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + const Arguments& quant_args_B, cudaStream_t stream); private: @@ -65,6 +73,7 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + const Arguments& quant_args_B, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); @@ -81,6 +90,7 @@ class MoeGemmRunner { int64_t gemm_n, int64_t gemm_k, int num_experts, + const Arguments& quant_args_B, cudaStream_t stream); private: diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu index faf2f37ea..d8496073f 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu @@ -13,7 +13,7 @@ * See the 
License for the specific language governing permissions and * limitations under the License. */ - + #pragma once #include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h" #include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h" @@ -22,7 +22,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 -template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +template class MoeGemmRunner< + __nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu new file mode 100644 index 000000000..92d63948c --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h" +#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h" +#include "helper.h" + +namespace phi { + +#ifdef PADDLE_CUDA_BF16 +template class MoeGemmRunner< + __nv_bfloat16, + cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>; +#endif + +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu index 8b3b77e2f..b82fbc107 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu @@ -21,7 +21,9 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 -template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +template class MoeGemmRunner< + __nv_bfloat16, + cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu index 6756855cf..97fdd104b 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu @@ -22,8 +22,9 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 -template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +template class MoeGemmRunner< + __nv_bfloat16, + cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>; #endif -} // namespace phi - +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu index 3bce2ccdb..a3d34b8e7 100644 --- 
a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu @@ -21,6 +21,7 @@ namespace phi { -template class MoeGemmRunner; +template class MoeGemmRunner>; -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu new file mode 100644 index 000000000..5d84c9cfc --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h" +#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h" +#include "helper.h" + +namespace phi { + +template class MoeGemmRunner< + half, cutlass::WintQuantTraits>; + +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu index cef3f9e8e..51707ebbb 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu @@ -21,6 +21,7 @@ namespace phi { -template class MoeGemmRunner; +template class MoeGemmRunner< + half, cutlass::WintQuantTraits>; -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu index 9e3923ddb..c796f9bbe 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu @@ -21,6 +21,7 @@ namespace phi { -template class MoeGemmRunner; +template class MoeGemmRunner< + half, cutlass::WintQuantTraits>; -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index 524953391..b5cb93ad3 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -24,9 +24,10 @@ #include #include #include -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" +#include "cutlass/array.h" +#include "cutlass/trace.h" +#include "cutlass/numeric_conversion.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" @@ -35,8 +36,11 @@ #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h" #include 
"paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h" -#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h" + +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/wint_type_traits.h" + #include "cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h" #include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h" @@ -48,17 +52,47 @@ #include "helper.h" namespace phi { -// ============================= Variable batched Gemm things -// =========================== + +template +struct CutlassLayoutB { + using Type = typename MixedGemmArchTraits::LayoutB; +}; + +template +struct CutlassLayoutB { + using Type = cutlass::layout::RowMajor; +}; + +template +struct CutlassGemmKernel { + using Type = + cutlass::gemm::kernel::MoeFCGemm; +}; + +template +struct CutlassGemmKernel { + using Type = + cutlass::gemm::kernel::Wint2xMoeFCGemm; +}; + +// ======================= Variable batched Gemm things ======================= template void generic_moe_gemm_kernelLauncher(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -67,6 +101,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, @@ -86,44 +121,26 @@ void generic_moe_gemm_kernelLauncher(const T* A, "Specialized for half, float"); #endif + using WeightType = typename WeightQuantTraits::WeightType; + static_assert( cutlass::platform::is_same::value || cutlass::platform::is_same::value || - cutlass::platform::is_same::value, - ""); + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)"); // The cutlass type for the input elements. This is needed to convert to // cutlass::half_t if necessary. - using ElementType_ = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::half_t, - T>::type; -#ifdef PADDLE_CUDA_BF16 - using ElementType = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::bfloat16_t, - ElementType_>::type; -#else - using ElementType = ElementType_; -#endif - - using CutlassWeightType_ = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::half_t, - WeightType>::type; -#ifdef PADDLE_CUDA_BF16 - using CutlassWeightType = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::bfloat16_t, - CutlassWeightType_>::type; -#else - using CutlassWeightType = CutlassWeightType_; -#endif + using ElementType = typename cutlass::CutlassDataType::Type; + using CutlassWeightType = typename cutlass::CutlassDataType::Type; + using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType; + using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType; // We need separate config for each architecture since we will target // different tensorcore instructions. For float, we do not target TCs. 
using MixedGemmArchTraits = cutlass::gemm::kernel:: - MixedGemmArchTraits; + MixedGemmArchTraits; using ElementAccumulator = typename MixedGemmArchTraits::AccType; using EpilogueOp = typename Epilogue::Op; // Finally, set up the kernel. - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + using BaseGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, - CutlassWeightType, - typename MixedGemmArchTraits::LayoutB, + CutlassMmaKernelType, + typename CutlassLayoutB::Type, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, @@ -155,14 +172,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; - using GemmKernel = - cutlass::gemm::kernel::MoeFCGemm; - + using GemmKernel = typename CutlassGemmKernel::Type; using GemmGrouped = cutlass::gemm::device::GemmGrouped; if (kernel_occupancy != nullptr) { @@ -181,19 +191,32 @@ void generic_moe_gemm_kernelLauncher(const T* A, typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + const uint8_t* local_scale_B = nullptr; + const float* code_scale_B = nullptr; + const float* code_zp_B = nullptr; + if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { + local_scale_B = quant_args_B.local_scale_ptr; + code_scale_B = quant_args_B.code_scale_ptr; + code_zp_B = quant_args_B.code_zp_ptr; + } + typename GemmGrouped::Arguments args( num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), + reinterpret_cast(B), reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, total_rows, gemm_n, - gemm_k); + gemm_k, + WeightQuantTraits::kQuantMethod, + local_scale_B, + code_scale_B, + code_zp_B); GemmGrouped gemm; @@ -222,7 +245,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, } template struct dispatch_stages { static void dispatch(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -240,6 +263,7 @@ struct dispatch_stages { int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, @@ -253,20 +277,20 @@ struct dispatch_stages { }; template struct dispatch_stages { static void dispatch(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -275,12 +299,13 @@ struct dispatch_stages struct dispatch_stages 2)>::type> { static void dispatch(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -326,12 +352,13 @@ struct dispatch_stages void dispatch_gemm_config(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -369,6 +397,7 @@ void dispatch_gemm_config(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, @@ -376,7 +405,7 @@ void dispatch_gemm_config(const T* A, #define dispatch_stages_macro(STAGE) \ case STAGE: \ dispatch_stages, \ @@ 
-425,10 +455,11 @@ void dispatch_gemm_config(const T* A, biases, \ C, \ total_rows_before_expert, \ - total_rows, \ + total_rows, \ gemm_n, \ gemm_k, \ num_experts, \ + quant_args_B, \ gemm_config, \ multi_processor_count, \ stream, \ @@ -438,14 +469,14 @@ void dispatch_gemm_config(const T* A, // This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. // This overload is only enabled when T == WeightType. template ::value && - std::is_same::value>::type* = + std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -454,6 +485,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, @@ -474,7 +506,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, default: throw std::runtime_error( "[dispatch_moe_gemm_to_cutlass] Config is invalid for same " - "type MoE tensorop GEMM."); + "type MoE tensorop GEMM for FP16/BF16."); break; } } @@ -483,14 +515,14 @@ void dispatch_moe_gemm_to_cutlass(const T* A, // Overload for quantize MoE GEMMs. We disable some warp configs here since they // will not be used and we can improve compile time template ::value && - !std::is_same::value>::type* = + !std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -499,28 +531,34 @@ void dispatch_moe_gemm_to_cutlass(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { if constexpr (std::is_same::value) { - switch (gemm_config.tile_config) { - dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); - dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); - case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); - break; - case CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[dispatch_moe_gemm_to_cutlass] gemm config should have " - "already been set by heuristic."); - break; - default: - throw std::runtime_error( - "[dispatch_moe_gemm_to_cutlass] Config is invalid for " - "mixed type tensorop GEMM."); - break; + if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) { + switch (gemm_config.tile_config) { + dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); + dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64); + case CutlassTileConfig::Undefined: + throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config should have " + "already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] Config is invalid for " + "mixed type tensorop GEMM for sm70."); + break; + } + } else { + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70."); } } else { switch (gemm_config.tile_config) { @@ -555,12 +593,12 @@ void dispatch_moe_gemm_to_cutlass(const T* A, // This overload will handle 
simt gemms. It is disabled via SFINAE for tensorop. template < typename T, - typename WeightType, + typename WeightQuantTraits, typename arch, typename EpilogueTag, typename std::enable_if::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -569,6 +607,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, @@ -594,8 +633,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, } } -template -MoeGemmRunner::MoeGemmRunner() { +template +MoeGemmRunner::MoeGemmRunner() { int device{-1}; check_cuda_error(cudaGetDevice(&device)); sm_ = getSMVersion(); @@ -603,11 +642,11 @@ MoeGemmRunner::MoeGemmRunner() { &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); } -template +template template -void MoeGemmRunner::dispatch_to_arch( +void MoeGemmRunner::dispatch_to_arch( const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, @@ -616,11 +655,12 @@ void MoeGemmRunner::dispatch_to_arch( int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) { #define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ - dispatch_moe_gemm_to_cutlass( \ + dispatch_moe_gemm_to_cutlass( \ A, \ B, \ weight_scales, \ @@ -631,6 +671,7 @@ void MoeGemmRunner::dispatch_to_arch( gemm_n, \ gemm_k, \ num_experts, \ + quant_args_B, \ gemm_config, \ sm_, \ multi_processor_count_, \ @@ -648,25 +689,28 @@ void MoeGemmRunner::dispatch_to_arch( } } -template +template template -void MoeGemmRunner::run_gemm( +void MoeGemmRunner::run_gemm( const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, - int64_t tune_total_rows, + int64_t actual_total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, cudaStream_t stream) { - static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool is_weight_only = !std::is_same::value; static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs = get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true); + static constexpr int warm_time = 5; static constexpr int test_time = 10; auto& gemmConfigManager = GemmConfigManager::Instance(); @@ -676,17 +720,19 @@ void MoeGemmRunner::run_gemm( gemm_n, gemm_k, GemmType::MOEGEMM, dtype, wdtype, num_experts}; CutlassGemmConfig chosen_config; auto chosen_config_optional = - gemmConfigManager.getBestConfig(gemmId, tune_total_rows); + gemmConfigManager.getBestConfig(gemmId, actual_total_rows); if (chosen_config_optional != std::nullopt) { chosen_config = chosen_config_optional.value(); } else { + size_t best_id = -1; float best_time = std::numeric_limits::max(); CutlassGemmConfig best_config; int profile_total_rows = - std::min(gemmConfigManager.nextPowerOfTwo(tune_total_rows), + std::min(gemmConfigManager.nextPowerOfTwo(actual_total_rows), gemmConfigManager.getMaxProfileM()); bool find_one = false; - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + size_t num_candidate_configs_size = 
candidate_configs.size(); + for (size_t ii = 0; ii < num_candidate_configs_size; ++ii) { try { for (int i = 0; i < warm_time; i++) { dispatch_to_arch(A, @@ -699,6 +745,7 @@ void MoeGemmRunner::run_gemm( gemm_n, gemm_k, num_experts, + quant_args_B, candidate_configs[ii], stream); } @@ -719,6 +766,7 @@ void MoeGemmRunner::run_gemm( gemm_n, gemm_k, num_experts, + quant_args_B, candidate_configs[ii], stream); } @@ -728,7 +776,9 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); check_cuda_error(cudaEventDestroy(start)); check_cuda_error(cudaEventDestroy(stop)); + //std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; if (elapsed < best_time) { + best_id = ii; best_time = elapsed; best_config = candidate_configs[ii]; } @@ -739,6 +789,7 @@ void MoeGemmRunner::run_gemm( } } if (find_one) { + //std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl; gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config); chosen_config = best_config; } else { @@ -756,23 +807,25 @@ void MoeGemmRunner::run_gemm( gemm_n, gemm_k, num_experts, + quant_args_B, chosen_config, stream); } -template -void MoeGemmRunner::moe_gemm_bias_act( +template +void MoeGemmRunner::moe_gemm_bias_act( const T* A, - const WeightType* B, + const typename WeightQuantTraits::WeightType* B, const T* weight_scales, const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, - int64_t tune_total_rows, + int64_t actual_total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, std::string activation_type, cudaStream_t stream) { if (activation_type == "none") { @@ -784,10 +837,11 @@ void MoeGemmRunner::moe_gemm_bias_act( C, total_rows_before_expert, total_rows, - tune_total_rows, + actual_total_rows, gemm_n, gemm_k, num_experts, + quant_args_B, stream); } else { run_gemm(A, @@ -797,27 +851,30 @@ void MoeGemmRunner::moe_gemm_bias_act( C, total_rows_before_expert, total_rows, - tune_total_rows, + actual_total_rows, gemm_n, gemm_k, num_experts, + quant_args_B, stream); } } } -template -void MoeGemmRunner::moe_gemm(const T* A, - const WeightType* B, - const T* weight_scales, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t tune_total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - cudaStream_t stream) { +template +void MoeGemmRunner::moe_gemm( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t actual_total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + cudaStream_t stream) { run_gemm(A, B, weight_scales, @@ -825,10 +882,11 @@ void MoeGemmRunner::moe_gemm(const T* A, C, total_rows_before_expert, total_rows, - tune_total_rows, + actual_total_rows, gemm_n, gemm_k, num_experts, + quant_args_B, stream); } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu index c427077a8..4cdc7f0b3 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_test.cu @@ -13,58 +13,60 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ // #include "paddle/phi/core/enforce.h" -#include -#include -#include -#include -#include #include "ctime" #include "iostream" #include "stdint.h" #include "stdlib.h" -#include "weight_process_utils.h" -#include -#include -#include -#include -#include -#include -#include #include "w4a4_gemm_configs.h" #include "w4a8_moe_gemm_kernel.h" +#include "weight_process_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // #include "paddle/phi/common/data_type.h" #include "cutlass/numeric_types.h" #include "cutlass/trace.h" #define USE_NVTX #ifdef USE_NVTX +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120900) +#include "nvtx3/nvToolsExt.h" +#else #include "nvToolsExt.h" -// DECLARE_string(cutlass_w4a8_best_config); - -const uint32_t colors[] = { 0xff00ff00, 0xff0000ff, 0xffffff00, 0xffff00ff, 0xff00ffff, 0xffff0000, 0xffffffff }; -const int num_colors = sizeof(colors)/sizeof(uint32_t); +#endif +const uint32_t colors[] = {0xff00ff00, 0xff0000ff, 0xffffff00, 0xffff00ff, + 0xff00ffff, 0xffff0000, 0xffffffff}; +const int num_colors = sizeof(colors) / sizeof(uint32_t); using CutlassTileConfig = CutlassTileConfig; using SplitKStyle = SplitKStyle; using CutlassGemmConfig = CutlassGemmConfig; - -#define PUSH_RANGE(name,cid) { \ - int color_id = cid; \ - color_id = color_id%num_colors;\ - nvtxEventAttributes_t eventAttrib = {0}; \ - eventAttrib.version = NVTX_VERSION; \ - eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; \ - eventAttrib.colorType = NVTX_COLOR_ARGB; \ - eventAttrib.color = colors[color_id]; \ - eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; \ - eventAttrib.message.ascii = name; \ - nvtxRangePushEx(&eventAttrib); \ -} +#define PUSH_RANGE(name, cid) \ + { \ + int color_id = cid; \ + color_id = color_id % num_colors; \ + nvtxEventAttributes_t eventAttrib = {0}; \ + eventAttrib.version = NVTX_VERSION; \ + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; \ + eventAttrib.colorType = NVTX_COLOR_ARGB; \ + eventAttrib.color = colors[color_id]; \ + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; \ + eventAttrib.message.ascii = name; \ + nvtxRangePushEx(&eventAttrib); \ + } #define POP_RANGE nvtxRangePop(); #else -#define PUSH_RANGE(name,cid) +#define PUSH_RANGE(name, cid) #define POP_RANGE #endif @@ -77,9 +79,7 @@ using CutlassGemmConfig = CutlassGemmConfig; // } template -static void PrintMatrix(const T* mat_d, - int num, - std::string name, +static void PrintMatrix(const T *mat_d, int num, std::string name, int numOfCols) { std::vector tmp(num); cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); @@ -102,19 +102,19 @@ static void PrintMatrix(const T* mat_d, outfile.close(); } -uint as_uint(const float x) { return *(uint*)&x; } +uint as_uint(const float x) { return *(uint *)&x; } uint16_t ConvertFloat2Half(const float x) { - const uint b = as_uint(x) + 0x00001000; // round-to-nearest-even: add last - // bit after truncated mantissa - const uint e = (b & 0x7F800000) >> 23; // exponent - const uint m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = - // 0x00800000-0x00001000 = decimal indicator - // flag - initial rounding + const uint b = as_uint(x) + 0x00001000; // round-to-nearest-even: add last + // bit after truncated mantissa + const uint e = (b & 0x7F800000) >> 23; // exponent + const uint m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = + // 0x00800000-0x00001000 = decimal indicator + // flag - initial rounding return (b & 0x80000000) >> 16 | (e > 112) * 
((((e - 112) << 10) & 0x7C00) | m >> 13) | ((e < 113) & (e > 101)) * ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | - (e > 143) * 0x7FFF; // sign : normalized : denormalized : saturate + (e > 143) * 0x7FFF; // sign : normalized : denormalized : saturate } inline float fp32_from_bits(uint32_t w) { @@ -274,9 +274,7 @@ float CPUHalfConvert2Float(const uint16_t h) { return fp32_from_bits(result); } -static void PrintHalfMatrix(const int16_t* mat_d, - int num, - std::string name, +static void PrintHalfMatrix(const int16_t *mat_d, int num, std::string name, int numOfCols) { std::vector tmp(num); cudaMemcpy(tmp.data(), mat_d, sizeof(int16_t) * num, cudaMemcpyDeviceToHost); @@ -298,9 +296,7 @@ static void PrintHalfMatrix(const int16_t* mat_d, } template -static void PrintMatrixCPU(const T* mat, - int num, - std::string name, +static void PrintMatrixCPU(const T *mat, int num, std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -319,17 +315,15 @@ static void PrintMatrixCPU(const T* mat, outfile.close(); } -static void PrintMatrixCPU_int4(const int8_t * mat, - int num, - std::string name, - int numOfCols){ +static void PrintMatrixCPU_int4(const int8_t *mat, int num, std::string name, + int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); std::stringstream ss; for (int i = 0; i < num / 2; ++i) { int32_t output_value = mat[i] & 0x0F; ss << static_cast(output_value) << " "; - output_value = (mat[i]>>4) & 0x0F; + output_value = (mat[i] >> 4) & 0x0F; ss << static_cast(output_value) << " "; if ((i * 2) % numOfCols == numOfCols - 2) { ss << std::endl; @@ -337,12 +331,9 @@ static void PrintMatrixCPU_int4(const int8_t * mat, } outfile << ss.str(); outfile.close(); - } template -static void PrintHalfMatrixCPU(const T* mat, - int num, - std::string name, +static void PrintHalfMatrixCPU(const T *mat, int num, std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -358,8 +349,8 @@ static void PrintHalfMatrixCPU(const T* mat, } template -void naive_matmul( - const T* a, const T* b, outputT* c, size_t m, size_t n, size_t k) { +void naive_matmul(const T *a, const T *b, outputT *c, size_t m, size_t n, + size_t k) { for (int ik = 0; ik < k; ik++) { for (int im = 0; im < m; im++) { for (int in = 0; in < n; in++) { @@ -370,39 +361,37 @@ void naive_matmul( } template -void naive_matmul_fused_dequantize_nf4(const T* a, - const T* b, - const ScaleType* col_scale, - const ScaleType* row_scale, - const int32_t* nf4_look_up_table, - outputT* c, - size_t num_experts, - int64_t* total_rows_before_experts, - size_t total_rows, - size_t n, - size_t k) { - // PrintMatrixCPU( - // a, total_rows * k, "naive_matmul_a", k); - // PrintMatrixCPU( - // b, num_experts*k*n, "naive_matmul_b", n); - // PrintMatrixCPU( - // c, total_rows * n, "naive_matmul_c", n); - // PrintMatrixCPU( - // row_scale, total_rows, "naive_matmul_row_scale", 1); +void naive_matmul_fused_dequantize_nf4(const T *a, const T *b, + const ScaleType *col_scale, + const ScaleType *row_scale, + const int32_t *nf4_look_up_table, + outputT *c, size_t num_experts, + int64_t *total_rows_before_experts, + size_t total_rows, size_t n, size_t k) { + // PrintMatrixCPU( + // a, total_rows * k, "naive_matmul_a", k); + // PrintMatrixCPU( + // b, num_experts*k*n, "naive_matmul_b", n); + // PrintMatrixCPU( + // c, total_rows * n, "naive_matmul_c", n); + // PrintMatrixCPU( + // row_scale, total_rows, "naive_matmul_row_scale", 1); - // PrintMatrixCPU( - 
// col_scale, num_experts * n, "naive_matmul_col_scale", n); + // PrintMatrixCPU( + // col_scale, num_experts * n, "naive_matmul_col_scale", n); - // PrintMatrixCPU( - // nf4_look_up_table, 16, "naive_matmul_nf4_lut", 1); - // std::cout<<"####nf4_look_up_table"<(&loop_up_table); + // PrintMatrixCPU( + // nf4_look_up_table, 16, "naive_matmul_nf4_lut", 1); + // std::cout<<"####nf4_look_up_table"<(&loop_up_table); for (int ie = 0; ie < num_experts; ie++) { int im_start, im_end; if (ie == 0) { @@ -422,9 +411,10 @@ void naive_matmul_fused_dequantize_nf4(const T* a, // std::cout<(a_val)<<", "; // std::cout<(b[ie * n * k + ik * n + in])<<", "; // std::cout<(b_val)<<", "; - // std::cout<(nf4_look_up_table[b_val])<<", " << std::endl; - // std::cout<(b_val); + // std::cout<(nf4_look_up_table[b_val])<<", " << + // std::endl; std::cout<(b_val); int32_t matmul_res = static_cast(a_val) * b_val_int32; // std::cout<(*reinterpret_cast(&r_val)); + float row_scale_val = + static_cast(*reinterpret_cast(&r_val)); // float row_scale_val = 1.0; - uint16_t c_val = ConvertFloat2Half(col_scale ? col_scale[ie * n + in] * 112 : 1.0); - float col_scale_val = static_cast(*reinterpret_cast(&c_val)); + uint16_t c_val = + ConvertFloat2Half(col_scale ? col_scale[ie * n + in] * 112 : 1.0); + float col_scale_val = + static_cast(*reinterpret_cast(&c_val)); // printf("##### (%d,%d) accu_val = %d\n",im,in, accum_val); - uint16_t res = ConvertFloat2Half(static_cast(accum_val) * col_scale_val * row_scale_val); - c[im * n + in] = static_cast(*reinterpret_cast(&res)); + uint16_t res = ConvertFloat2Half(static_cast(accum_val) * + col_scale_val * row_scale_val); + c[im * n + in] = static_cast(*reinterpret_cast(&res)); } } } - // PrintMatrixCPU( - // c, total_rows * n, "naive_matmul_c_computed", n); + // PrintMatrixCPU( + // c, total_rows * n, "naive_matmul_c_computed", n); } // Author (zhengzekang): we use float to monitor half matmul in CPU. -void CheckHalfDiff(int16_t* device_res, - float* host_result, - size_t elem_cnt, - float atol, - float rtol) { +void CheckHalfDiff(int16_t *device_res, float *host_result, size_t elem_cnt, + float atol, float rtol) { std::vector device_data(elem_cnt); - cudaMemcpy(device_data.data(), - device_res, - sizeof(int16_t) * elem_cnt, + cudaMemcpy(device_data.data(), device_res, sizeof(int16_t) * elem_cnt, cudaMemcpyDeviceToHost); for (size_t i = 0; i < elem_cnt; i++) { @@ -470,17 +459,13 @@ void CheckHalfDiff(int16_t* device_res, printf( "Here in Idx: %d, CUDA result is: %f, Host result is: %f, absolute " "diff val is: %f \n", - i, - device_res_val, - host_res_val, - absolute_diff); + i, device_res_val, host_res_val, absolute_diff); return; } } printf("======= Check Success! 
=======\n"); } - // uint16_t float_to_half(const float x) { // IEEE-754 16-bit floating-point // format (without infinity): 1-5-10, exp-15, +-131008.0, +-6.1035156E-5, // +-5.9604645E-8, 3.311 digits @@ -495,7 +480,7 @@ void CheckHalfDiff(int16_t* device_res, // } template -__global__ void CUDAPrintHalfMatrix(T* output, int m, int n) { +__global__ void CUDAPrintHalfMatrix(T *output, int m, int n) { for (int row_idx = 0; row_idx < m; row_idx++) { for (int col_idx = 0; col_idx < n; col_idx++) { // printf("%d ", static_cast(static_cast(output[row_idx * @@ -506,62 +491,64 @@ __global__ void CUDAPrintHalfMatrix(T* output, int m, int n) { } } - -CutlassGemmConfig GetGemmConfig(int token_nums, std::vector & gemm_config_tuple){ +CutlassGemmConfig GetGemmConfig(int token_nums, + std::vector &gemm_config_tuple) { int len_of_gemm_config_tuple = gemm_config_tuple.size(); - if(len_of_gemm_config_tuple == 0){ - CutlassGemmConfig gemm_config = CutlassGemmConfig{CutlassTileConfig::Undefined,SplitKStyle::NO_SPLIT_K,-1,-1}; + if (len_of_gemm_config_tuple == 0) { + CutlassGemmConfig gemm_config = CutlassGemmConfig{ + CutlassTileConfig::Undefined, SplitKStyle::NO_SPLIT_K, -1, -1}; return gemm_config; } CutlassGemmConfig gemm_config = CutlassGemmConfig{ - CutlassTileConfig(gemm_config_tuple[len_of_gemm_config_tuple - 4]), - SplitKStyle(gemm_config_tuple[len_of_gemm_config_tuple - 3]), - gemm_config_tuple[len_of_gemm_config_tuple - 2], - gemm_config_tuple[len_of_gemm_config_tuple-1] - }; + CutlassTileConfig(gemm_config_tuple[len_of_gemm_config_tuple - 4]), + SplitKStyle(gemm_config_tuple[len_of_gemm_config_tuple - 3]), + gemm_config_tuple[len_of_gemm_config_tuple - 2], + gemm_config_tuple[len_of_gemm_config_tuple - 1]}; // 0,1,2,3 ,4 ,5 ,6 // gemm_config_tuple:[m,n,k,tile_config,split_k_style,split_k_factor,stages] - for(int i=0; i -void get_tensor_from_file(const std::string file_path, int64_t numel, T* tensor_ptr){ +template +void get_tensor_from_file(const std::string file_path, int64_t numel, + T *tensor_ptr) { std::fstream datafile; datafile.open(file_path, std::ios_base::in | std::ios_base::out); int index = 0; std::string line; while (std::getline(datafile, line)) { - std::istringstream iss(line); - if(index == 0){ - std::cout<> number) { - tensor_ptr[index] = static_cast(number); - if(index==0){ - std::cout<(tensor_ptr[0])<> number) { + tensor_ptr[index] = static_cast(number); + if (index == 0) { + std::cout << file_path << ": " << number << "-" + << static_cast(tensor_ptr[0]) << std::endl; } + index++; + } } - std::cout< uniform(-0.02, 0.02); std::default_random_engine random_engine(0); @@ -573,28 +560,29 @@ int main(int argc, char* argv[]) { size_t tokens_per_expert = strtol(argv[4], nullptr, 0); size_t total_rows = num_experts * tokens_per_expert; std::vector total_rows_before_experts; - std::cout<<"total_rows_before_experts: "; - for(int i = 0; i < num_experts; ++i){ - total_rows_before_experts.push_back(tokens_per_expert*(i+1)); - std::cout<= 6) { do_check = strtol(argv[5], nullptr, 0); } if (do_check) { - std::cout<<"####do check#####"<= 7) { do_gemm_config_searching = strtol(argv[6], nullptr, 0); } if (do_gemm_config_searching) { - std::cout<<"####do gemm config searching#####"<= 8){ + if (argc >= 8) { gemm_config_search_log_file = argv[8]; } std::ifstream gemm_config_file(gemm_config_search_log_file); std::string a_data_file = ""; - if(argc >= 9){ + if (argc >= 9) { a_data_file = argv[9]; } std::string b_data_file = ""; - if(argc >= 10){ + if (argc >= 10) { b_data_file = argv[10]; } - 
std::string row_scale_data_file = ""; - if(argc >= 11){ + if (argc >= 11) { row_scale_data_file = argv[11]; } - std::string col_scale_data_file = ""; - if(argc >= 12){ + if (argc >= 12) { col_scale_data_file = argv[12]; } std::vector config_vec; if (gemm_config_file.is_open()) { - std::string line; - while (std::getline(gemm_config_file, line)) { - // using printf() in all tests for consistency - // printf("%s", line.c_str()); - if(line.find("#####best_gemm_config_tuple#####")!=std::string::npos){ - std::cout<> temp) { - config_vec.push_back(temp); - } - } + std::string line; + while (std::getline(gemm_config_file, line)) { + // using printf() in all tests for consistency + // printf("%s", line.c_str()); + if (line.find("#####best_gemm_config_tuple#####") != std::string::npos) { + std::cout << line << std::endl; + std::string config_str = line.substr(32, std::string::npos); + std::istringstream in_str(config_str); + int temp; + while (in_str >> temp) { + config_vec.push_back(temp); + } } - gemm_config_file.close(); + } + gemm_config_file.close(); } // auto best_gemm_config = GetGemmConfig(m, config_vec); const auto kWarmTime = 1; const auto kTestTime = 100; - auto mixed_gemm_runner = - W4A8MoeGemmRunner(); + auto mixed_gemm_runner = W4A8MoeGemmRunner(); // int mixgemm_max_size = std::max(m, k); - int mixgemm_workspace_size_bytes = 1 * 1024*1024*1024; // 1G workspace - std::cout<<"mixgemm_workspace_size_bytes: "< a_int(total_rows * k); if (do_check) { - if(a_data_file==""){ + if (a_data_file == "") { for (int i = 0; i < total_rows * k; i++) { - // a_int[i] = 1; - a_int[i] = rand() % 16; + // a_int[i] = 1; + a_int[i] = rand() % 16; // a_int[i] = rand() % 128 - 64; // if(i>=k){ // // a_int[i] = a_int[i%k]; @@ -677,16 +662,17 @@ int main(int argc, char* argv[]) { // } } } else { - std::cout<<"get a data from: "<< a_data_file << std::endl; - get_tensor_from_file(a_data_file, total_rows*k, a_int.data()); + std::cout << "get a data from: " << a_data_file << std::endl; + get_tensor_from_file(a_data_file, total_rows * k, + a_int.data()); } // PrintMatrixCPU(a_int.data(),total_rows*k,"a_int8_cpu",n); } std::vector b_int(num_experts * k * n); if (do_check) { - for (int ii = 0 ; ii < num_experts; ++ii) { - for (int i = ii*k*n; i < (ii+1)*k * n; i++) { + for (int ii = 0; ii < num_experts; ++ii) { + for (int i = ii * k * n; i < (ii + 1) * k * n; i++) { // author zhengzekang b_int[i] = rand() % 16; // b_int[i] = 1; @@ -706,99 +692,103 @@ int main(int argc, char* argv[]) { // } } } - // PrintMatrixCPU(b_int.data(),num_experts * k * n,"b_int8_cpu_init",n); + // PrintMatrixCPU(b_int.data(),num_experts * k * + // n,"b_int8_cpu_init",n); } - - std::vector nf4_look_up_table(16); std::vector nf4_look_up_table_compress(16); if (do_check) { - for(int i=0;i<4;++i){ - nf4_look_up_table_compress[i]=0; + for (int i = 0; i < 4; ++i) { + nf4_look_up_table_compress[i] = 0; } - for(int i=0;i<16;++i){ - int32_t left4i=i<<4; - int8_t tmp = * reinterpret_cast(&(left4i)); + for (int i = 0; i < 16; ++i) { + int32_t left4i = i << 4; + int8_t tmp = *reinterpret_cast(&(left4i)); int32_t tmp_int32 = static_cast(tmp); - nf4_look_up_table[i]=tmp_int32; + nf4_look_up_table[i] = tmp_int32; } - for(int i=0;i<16;++i){ - nf4_look_up_table_compress[i] = (static_cast(nf4_look_up_table[i])); + for (int i = 0; i < 16; ++i) { + nf4_look_up_table_compress[i] = + (static_cast(nf4_look_up_table[i])); } - std::cout<<"####nf4_look_up_table"< packed_b_int(num_experts * k * n / 2); if (do_check) { - for (int ie = 0 ; ie < num_experts; ++ie) { + for 
(int ie = 0; ie < num_experts; ++ie) { int offset = ie * k * n / 2; - for (int packed_i = 0; packed_i < k * n / 2; packed_i++){ + for (int packed_i = 0; packed_i < k * n / 2; packed_i++) { packed_b_int[offset + packed_i] = 0; - packed_b_int[offset + packed_i] |= b_int[(offset + packed_i)*2] & 0x0f; - packed_b_int[offset + packed_i] |= (b_int[(offset + packed_i)*2 + 1] & 0x0f) << 4; + packed_b_int[offset + packed_i] |= + b_int[(offset + packed_i) * 2] & 0x0f; + packed_b_int[offset + packed_i] |= + (b_int[(offset + packed_i) * 2 + 1] & 0x0f) << 4; } } } std::vector b_int_processed(num_experts * k * n / 2); - std::vector b_int_processed_2(num_experts * k * n/2); - std::vector b_int_processed_3(num_experts * k * n/2); + std::vector b_int_processed_2(num_experts * k * n / 2); + std::vector b_int_processed_3(num_experts * k * n / 2); if (do_check) { printf("do check\n"); - if(b_data_file == "") { - for (int ie = 0; ie < num_experts ; ie++) { + if (b_data_file == "") { + for (int ie = 0; ie < num_experts; ie++) { // PrintMatrixCPU_int4(packed_b_int.data(),num_experts*k*n,"w4a8_packed_b_int4",n); - permute_B_rows_for_mixed_gemm_int4<4>(b_int_processed.data() + ie * k * n / 2, - packed_b_int.data() + ie * k * n / 2, - std::vector{k, n}, - (int64_t)80); + permute_B_rows_for_mixed_gemm_int4<4>( + b_int_processed.data() + ie * k * n / 2, + packed_b_int.data() + ie * k * n / 2, std::vector{k, n}, + (int64_t)80); // PrintMatrixCPU_int4(b_int_processed.data(),num_experts*k*n,"w4a8_permuted_int4",n); - std::cout<<"before subbyte_transpose_impl_int4"<{k, n}); + b_int_processed.data() + ie * k * n / 2, + std::vector{k, n}); // PrintMatrixCPU_int4(b_int_processed_2.data(),num_experts*k*n,"w4a8_subbyte_transpose_impl_int4",k); - interleave_column_major_tensor_int4(b_int_processed_3.data() + ie * k * n / 2, - b_int_processed_2.data() + ie * k * n / 2, - std::vector{k, n}); + interleave_column_major_tensor_int4( + b_int_processed_3.data() + ie * k * n / 2, + b_int_processed_2.data() + ie * k * n / 2, + std::vector{k, n}); // PrintMatrixCPU_int4(b_int_processed_3.data(),num_experts*k*n,"w4a8_interleave_column_major_tensor_int4",k); - add_bias_and_interleave_int4s_inplace(b_int_processed_3.data() + ie * k * n / 2, k * n); + add_bias_and_interleave_int4s_inplace( + b_int_processed_3.data() + ie * k * n / 2, k * n); } } else { - get_tensor_from_file(b_data_file,num_experts*k*n/2, b_int_processed_3.data()); + get_tensor_from_file( + b_data_file, num_experts * k * n / 2, b_int_processed_3.data()); } - // PrintMatrixCPU_int4(b_int_processed_3.data(), - // num_experts*k*n/2, - // "w4a8_add_bias_and_interleave_int4s_inplace", - // k); + // PrintMatrixCPU_int4(b_int_processed_3.data(), + // num_experts*k*n/2, + // "w4a8_add_bias_and_interleave_int4s_inplace", + // k); // PrintMatrixCPU(b_int_processed_3.data(),num_experts*k*n/2,"b_int8_cpu",n); // TODO(zhengzekang): temporary use uint16_t instead of half. 
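      // Editor's note (recap of the reference path above, not part of the
      // original patch): when no b_data_file is supplied, the packed int4
      // weights are converted per expert into the layout the mixed-GEMM kernel
      // expects, in four steps (as the function names suggest):
      //   1. permute_B_rows_for_mixed_gemm_int4<4>  - permute rows for the
      //      SM80 load pattern (the `80` passed above);
      //   2. subbyte_transpose_impl_int4            - transpose the packed
      //      int4 tensor to column-major;
      //   3. interleave_column_major_tensor_int4    - interleave columns of the
      //      column-major tensor;
      //   4. add_bias_and_interleave_int4s_inplace  - bias/interleave the int4
      //      values in place.
      // Otherwise the already-preprocessed weights are read from b_data_file
      // via get_tensor_from_file.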
} - std::cout<<"done weight interleaved;"< c_float(total_rows * n); if (do_check) { @@ -807,7 +797,7 @@ int main(int argc, char* argv[]) { } } - std::cout<<"#### 2"< c_half(total_rows * n); if (do_check) { @@ -816,11 +806,11 @@ int main(int argc, char* argv[]) { } } - std::cout<<"#### 3"< row_scale_float(total_rows); if (do_check) { - if(row_scale_data_file==""){ + if (row_scale_data_file == "") { for (int32_t i = 0; i < row_scale_float.size(); i++) { // row_scale_float[i] = 0.1; // row_scale_float[i] = uniform(random_engine) * 0.1; @@ -830,29 +820,31 @@ int main(int argc, char* argv[]) { // } } } else { - get_tensor_from_file(row_scale_data_file,total_rows,row_scale_float.data()); + get_tensor_from_file(row_scale_data_file, total_rows, + row_scale_float.data()); } // PrintMatrixCPU(row_scale_float.data(),total_rows,"row_scale_float_cpu",total_rows); } - std::vector col_scale_float(num_experts*n); + std::vector col_scale_float(num_experts * n); if (do_check) { - if(col_scale_data_file == "") { + if (col_scale_data_file == "") { for (int32_t i = 0; i < col_scale_float.size(); i++) { // col_scale_float[i] = 0.04; - col_scale_float[i] = uniform(random_engine) * 0.06 * uniform(random_engine) * 0.1; + col_scale_float[i] = + uniform(random_engine) * 0.06 * uniform(random_engine) * 0.1; // col_scale_float[i] = 0; // if(i<1){ // col_scale_float[i] = 1; // } } } else { - get_tensor_from_file(col_scale_data_file,num_experts*n,col_scale_float.data()); + get_tensor_from_file(col_scale_data_file, num_experts * n, + col_scale_float.data()); } // PrintMatrixCPU(col_scale_float.data(),num_experts*n,"col_scale_float_cpu",n); } - std::vector row_scale_half(total_rows); if (do_check) { for (int32_t i = 0; i < row_scale_half.size(); i++) { @@ -861,8 +853,7 @@ int main(int argc, char* argv[]) { } } - - std::vector col_scale_half(num_experts*n); + std::vector col_scale_half(num_experts * n); if (do_check) { for (int32_t i = 0; i < col_scale_float.size(); i++) { // col_scale_float[i] = 1; @@ -870,86 +861,76 @@ int main(int argc, char* argv[]) { } } - std::cout<<"done c init"<(d_a_int), - reinterpret_cast((void*)d_b_int), - cutlass::epilogue::QuantMode::PerTokenChannelQuant, - reinterpret_cast(d_col_scale_half), - reinterpret_cast(d_row_scale_half), - reinterpret_cast(d_nf4_look_up_table), - reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), - -1, - total_rows, - n, - k, - mixgemm_workspace_data, - mixgemm_workspace_size_bytes, - num_experts, - 0, - test_config); + mixed_gemm_runner.moe_gemm( + reinterpret_cast(d_a_int), + reinterpret_cast((void *)d_b_int), + cutlass::epilogue::QuantMode::PerTokenChannelQuant, + reinterpret_cast(d_col_scale_half), + reinterpret_cast(d_row_scale_half), + reinterpret_cast(d_nf4_look_up_table), + reinterpret_cast(d_c_int), + reinterpret_cast(d_total_rows_before_experts), -1, + total_rows, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, + num_experts, 0, test_config); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { std::cout << "error: " << cudaGetErrorString(err) << std::endl; } else { - std::cout<<"cuda success" < 0){ // cudaDeviceSynchronize(); @@ -959,18 +940,25 @@ int main(int argc, char* argv[]) { // std::string nvtx_name = "int4_gemm_" + std::to_string(m) + "-" // + std::to_string(n) + "-" // + std::to_string(k) + "-" - // + std::to_string(static_cast::type>(best_gemm_config.tile_config)) + "-" - // + std::to_string(static_cast::type>(best_gemm_config.split_k_style))+"-" - // + 
std::to_string(best_gemm_config.split_k_factor)+"-" - // + std::to_string(best_gemm_config.stages); + // + + // std::to_string(static_cast::type>(best_gemm_config.tile_config)) + // + "-" + // + + // std::to_string(static_cast::type>(best_gemm_config.split_k_style))+"-" + // + + // std::to_string(best_gemm_config.split_k_factor)+"-" + // + + // std::to_string(best_gemm_config.stages); // PUSH_RANGE(nvtx_name.c_str(), 1) // } // mixed_gemm_runner.moe_gemm(reinterpret_cast(d_a_int), - // reinterpret_cast((void*)d_b_int), + // reinterpret_cast((void*)d_b_int), // QuantMode::PerTokenChannelQuant, // reinterpret_cast(d_col_scale_half), // reinterpret_cast(d_row_scale_half), - // reinterpret_cast(d_nf4_look_up_table), + // reinterpret_cast(d_nf4_look_up_table), // reinterpret_cast(d_c_int), // total_rows_before_exports, // m, @@ -987,80 +975,88 @@ int main(int argc, char* argv[]) { // } // cudaDeviceSynchronize(); // auto stop = std::chrono::system_clock::now(); - // auto duration = std::chrono::duration_cast((stop - start)); - // // std::cout<<"avg time for "<((stop - start)); + // // std::cout<<"avg time for "< all_cutlass_tile_configs{ - CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, - CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape32x256x64_WarpShape32x64x64, - CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64, - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, - CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64, + CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape32x256x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64, }; std::vector all_split_k_style{SplitKStyle::NO_SPLIT_K}; - - for (auto & tile_config : all_cutlass_tile_configs ){ - for (auto & split_k_style : all_split_k_style){ - for (int stages = 3; stages<=7; ++stages){ - for (int split_k_factor = 1; split_k_factor <=1; split_k_factor*=2){ - auto test_gemm_config = CutlassGemmConfig{tile_config,split_k_style,split_k_factor,stages}; + for (auto &tile_config : all_cutlass_tile_configs) { + for (auto &split_k_style : all_split_k_style) { + for (int stages = 3; stages <= 7; ++stages) { + for (int split_k_factor = 1; split_k_factor <= 1; + split_k_factor *= 2) { + auto test_gemm_config = CutlassGemmConfig{ + tile_config, split_k_style, split_k_factor, stages}; cudaEvent_t begin, end; cudaDeviceSynchronize(); cudaEventCreate(&begin); cudaEventCreate(&end); cudaEventRecord(begin, 0); - for (int i = 0; i(d_a_int), - reinterpret_cast((void*)d_b_int), - cutlass::epilogue::QuantMode::PerTokenChannelQuant, - reinterpret_cast(d_col_scale_half), - reinterpret_cast(d_row_scale_half), - reinterpret_cast(d_nf4_look_up_table), - reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), - -1, - total_rows, - n, - k, - mixgemm_workspace_data, - mixgemm_workspace_size_bytes, - num_experts, - 0, - test_gemm_config); + mixed_gemm_runner.moe_gemm( + reinterpret_cast(d_a_int), + reinterpret_cast((void *)d_b_int), + cutlass::epilogue::QuantMode::PerTokenChannelQuant, + reinterpret_cast(d_col_scale_half), + 
reinterpret_cast(d_row_scale_half), + reinterpret_cast(d_nf4_look_up_table), + reinterpret_cast(d_c_int), + reinterpret_cast(d_total_rows_before_experts), -1, + total_rows, n, k, mixgemm_workspace_data, + mixgemm_workspace_size_bytes, num_experts, 0, + test_gemm_config); } cudaEventRecord(end, 0); auto cuda_error = cudaDeviceSynchronize(); float cost_time; cudaEventElapsedTime(&cost_time, begin, end); - float avg_time = cost_time/static_cast(kTestTime) * 1000; + float avg_time = cost_time / static_cast(kTestTime) * 1000; if (cuda_error != cudaSuccess) { avg_time = 999999999; - std::cout<<"#### test gemm_config, error " - << " with split-k factor: "<::type>(test_gemm_config.tile_config) - << " split_k_style: "<::type>(test_gemm_config.split_k_style)<<" stages: "<< test_gemm_config.stages <::type>( + test_gemm_config.tile_config) + << " split_k_style: " + << static_cast::type>( + test_gemm_config.split_k_style) + << " stages: " << test_gemm_config.stages << std::endl; } - std::cout<<"#### test gemm_config, avg_time: "<::type>(test_gemm_config.tile_config) - << " split_k_style: "<::type>(test_gemm_config.split_k_style)<<" stages: "<< test_gemm_config.stages <::type>( + test_gemm_config.tile_config) + << " split_k_style: " + << static_cast::type>( + test_gemm_config.split_k_style) + << " stages: " << test_gemm_config.stages << std::endl; - if(avg_time::type>(best_config.tile_config) - << " split_k_style: "<::type>(best_config.split_k_style)<<" stages: "<< best_config.stages <::type>(best_config.tile_config)<<" " - <::type>(best_config.split_k_style)<<" " - <::type>( + best_config.tile_config) + << " split_k_style: " + << static_cast::type>( + best_config.split_k_style) + << " stages: " << best_config.stages << std::endl; + std::cout << "#####best_gemm_config_tuple##### " << total_rows << " " << n + << " " << k << " " << num_experts << " " + << static_cast::type>( + best_config.tile_config) + << " " + << static_cast::type>( + best_config.split_k_style) + << " " << best_config.split_k_factor << " " << best_config.stages + << std::endl; - - // std::string output_config_path = is_encryption ? "moe_w4a8_tuned_config.config" : "moe_w4a8_tuned_config.csv"; + // std::string output_config_path = is_encryption ? 
+ // "moe_w4a8_tuned_config.config" : "moe_w4a8_tuned_config.csv"; std::string output_config_path = "moe_w4a8_tuned_config.csv"; - int fd = open(output_config_path.c_str(), O_WRONLY | O_CREAT | O_APPEND, 0644); - if (fd == -1) { - perror("open error"); - return 1; + int fd = + open(output_config_path.c_str(), O_WRONLY | O_CREAT | O_APPEND, 0644); + if (fd == -1) { + perror("open error"); + return 1; } std::ofstream outfile; if (flock(fd, LOCK_EX) == -1) { - perror("flock error"); - close(fd); - return 1; + perror("flock error"); + close(fd); + return 1; } outfile.open(output_config_path, std::ios::app); - outfile << total_rows << "," - << n << "," - << k << "," - << num_experts << "," - << static_cast::type>(best_config.tile_config) << "," - << static_cast::type>(best_config.split_k_style)<< "," - << best_config.split_k_factor << "," - << best_config.stages <<"\n"; + outfile << total_rows << "," << n << "," << k << "," << num_experts << "," + << static_cast::type>( + best_config.tile_config) + << "," + << static_cast::type>( + best_config.split_k_style) + << "," << best_config.split_k_factor << "," << best_config.stages + << "\n"; // if (!is_encryption) { // outfile << tokens_per_expert << "," // << n << "," // << k << "," - // << static_cast::type>(best_config.tile_config) << "," - // << static_cast::type>(best_config.split_k_style)<< "," + // << + // static_cast::type>(best_config.tile_config) + // << "," + // << + // static_cast::type>(best_config.split_k_style)<< + // "," // << best_config.split_k_factor << "," // << best_config.stages <<"\n"; // } else { @@ -1121,8 +1128,12 @@ int main(int argc, char* argv[]) { // ss << tokens_per_expert << "," // << n << "," // << k << "," - // << static_cast::type>(best_config.tile_config) << "," - // << static_cast::type>(best_config.split_k_style) << "," + // << + // static_cast::type>(best_config.tile_config) + // << "," + // << + // static_cast::type>(best_config.split_k_style) + // << "," // << best_config.split_k_factor << "," // << best_config.stages; // std::string encrypted_str = paddle::operators::base64_encode(ss.str()); @@ -1130,84 +1141,76 @@ int main(int argc, char* argv[]) { // } outfile.flush(); if (flock(fd, LOCK_UN) == -1) { - perror("flock error (unlock)"); - // 注意:即使解锁失败,也应尽量关闭文件描述符 + perror("flock error (unlock)"); + // 注意:即使解锁失败,也应尽量关闭文件描述符 } outfile.close(); close(fd); - - if (do_check) { std::cout << "=== do accuracy check " << std::endl; - cudaMemset(d_c_int, 0, total_rows*n*sizeof(uint16_t)); - PrintHalfMatrix(static_cast(d_c_int), - total_rows * n, - "CUDA_c_dequantize_fp16_output_before_gemm", - n); + cudaMemset(d_c_int, 0, total_rows * n * sizeof(uint16_t)); + PrintHalfMatrix(static_cast(d_c_int), total_rows * n, + "CUDA_c_dequantize_fp16_output_before_gemm", n); - mixed_gemm_runner.moe_gemm(reinterpret_cast(d_a_int), - reinterpret_cast((void*)d_b_int), - cutlass::epilogue::QuantMode::PerChannelQuant, - reinterpret_cast(d_col_scale_half), - nullptr, // reinterpret_cast(d_row_scale_half), - nullptr, // reinterpret_cast(d_nf4_look_up_table), - reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), - -1, - total_rows, - n, - k, - mixgemm_workspace_data, - mixgemm_workspace_size_bytes, - num_experts, - 0); + mixed_gemm_runner.moe_gemm( + reinterpret_cast(d_a_int), + reinterpret_cast((void *)d_b_int), + cutlass::epilogue::QuantMode::PerChannelQuant, + reinterpret_cast(d_col_scale_half), + nullptr, // reinterpret_cast(d_row_scale_half), + nullptr, // reinterpret_cast(d_nf4_look_up_table), + 
reinterpret_cast(d_c_int), + reinterpret_cast(d_total_rows_before_experts), -1, + total_rows, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, + num_experts, 0); cudaDeviceSynchronize(); - // PrintMatrix(reinterpret_cast(d_nf4_look_up_table),4,"d_nf4_look_up_table",1); + // PrintMatrix(reinterpret_cast(d_nf4_look_up_table),4,"d_nf4_look_up_table",1); printf("##### d_nf4_look_up_table address: %p \n", d_nf4_look_up_table); - naive_matmul_fused_dequantize_nf4(a_int.data(), - b_int.data(), - col_scale_float.data(), - nullptr, // row_scale_float.data(), - nullptr, // nf4_look_up_table.data(), - c_float.data(), - num_experts, - total_rows_before_experts.data(), - total_rows, - n, - k); - PrintMatrixCPU( - c_float.data(), total_rows * n, "CPU_c_fake_fp16_dequantize_output_base", n); - PrintHalfMatrix(static_cast(d_c_int), - total_rows * n, - "CUDA_c_dequantize_fp16_output", - n); - CheckHalfDiff( - static_cast(d_c_int), c_float.data(), total_rows * n, 1e-4, 1e-2); + naive_matmul_fused_dequantize_nf4( + a_int.data(), b_int.data(), col_scale_float.data(), + nullptr, // row_scale_float.data(), + nullptr, // nf4_look_up_table.data(), + c_float.data(), num_experts, total_rows_before_experts.data(), + total_rows, n, k); + PrintMatrixCPU(c_float.data(), total_rows * n, + "CPU_c_fake_fp16_dequantize_output_base", n); + PrintHalfMatrix(static_cast(d_c_int), total_rows * n, + "CUDA_c_dequantize_fp16_output", n); + CheckHalfDiff(static_cast(d_c_int), c_float.data(), + total_rows * n, 1e-4, 1e-2); } - // if(kTestTime > 0){ // cudaDeviceSynchronize(); // auto start = std::chrono::system_clock::now(); // for (int i = 0; i < kTestTime; i++) { // if(i == 0){ - // std::string nvtx_name = "int4_gemm_" + std::to_string(tokens_per_expert) + "-" + // std::string nvtx_name = "int4_gemm_" + + // std::to_string(tokens_per_expert) + "-" // + std::to_string(n) + "-" // + std::to_string(k) + "-" - // + std::to_string(static_cast::type>(best_config.tile_config)) + "-" - // + std::to_string(static_cast::type>(best_config.split_k_style))+"-" - // + std::to_string(best_config.split_k_factor)+"-" - // + std::to_string(best_config.stages); + // + + // std::to_string(static_cast::type>(best_config.tile_config)) + // + "-" + // + + // std::to_string(static_cast::type>(best_config.split_k_style))+"-" + // + + // std::to_string(best_config.split_k_factor)+"-" + // + + // std::to_string(best_config.stages); // PUSH_RANGE(nvtx_name.c_str(), 1) // } // mixed_gemm_runner.moe_gemm(reinterpret_cast(d_a_int), - // reinterpret_cast((void*)d_b_int), + // reinterpret_cast((void*)d_b_int), // cutlass::epilogue::QuantMode::PerTokenChannelQuant, // reinterpret_cast(d_col_scale_half), // reinterpret_cast(d_row_scale_half), - // reinterpret_cast(d_nf4_look_up_table), + // reinterpret_cast(d_nf4_look_up_table), // reinterpret_cast(d_c_int), // reinterpret_cast(d_total_rows_before_experts), // total_rows, @@ -1224,8 +1227,10 @@ int main(int argc, char* argv[]) { // } // cudaDeviceSynchronize(); // auto stop = std::chrono::system_clock::now(); - // auto duration = std::chrono::duration_cast((stop - start)); - // std::cout<<"avg time for "<((stop - start)); + // std::cout<<"avg time for "< +get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) { + int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1]; + return {m, n, k, 1}; +} + +template +void cutlass_gemm_caller( + phi::Place device, cute::Shape prob_shape, + typename GemmKernel::MainloopArguments mainloop_args, + typename GemmKernel::EpilogueArguments epilogue_args, + 
typename GemmKernel::TileSchedulerArguments scheduler = {}) { + cutlass::KernelHardwareInfo hw_info; + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + phi::Allocator *allocator = paddle::GetAllocator(device); + auto workspace = allocator->Allocate(workspace_size); + + auto stream = paddle::GetCurrentCUDAStream(device)->raw_stream(); + + cutlass::Status status = gemm_op.run(args, workspace->ptr(), stream); + CUTLASS_CHECK(status); +} + +template +void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a, + paddle::Tensor const &b, + EpilogueArgs &&...epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; + using GemmKernel = typename Gemm::GemmKernel; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = StrideC; + using StrideAux = StrideC; + + typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); + auto [M, N, K, L] = prob_shape; + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + StrideC c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + StrideD d_stride = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + StrideAux aux_stride = d_stride; + + auto a_ptr = static_cast(const_cast(a.data())); + auto b_ptr = static_cast(const_cast(b.data())); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(const_cast(out.data())); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, d_stride}; + + cutlass_gemm_caller(a.place(), prob_shape, mainloop_args, + epilogue_args); +} + +} // namespace fastdeploy::c3x diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh new file mode 100644 index 000000000..26278a79f --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh @@ -0,0 +1,149 @@ +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh + +#pragma once + +// clang-format will break include orders +// clang-format off + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass_helper.h" +#include "helper.h" +// clang-format on + +/* + Epilogues defined in, + csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp, + must contain a public type named EVTCompute of type Sm90EVT, as well as a + static prepare_args function that constructs an 
EVTCompute::Arguments struct. +*/ + +using namespace cute; + +namespace fastdeploy { + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Epilogue = Epilogue_; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + // These are the minimum alignments needed for the kernels to compile + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = 4; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, + AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, AlignmentAB, + ElementAB, cutlass::layout::ColumnMajor, AlignmentAB, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; +}; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm_sm100 { + using ElementAB = ElementAB_; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementD_; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = AlignmentC; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + using Epilogue = Epilogue_; + + // MMA type + using ElementAccumulator = float; + + // Epilogue types + using ElementBias = cutlass::half_t; + using ElementCompute = float; + using ElementAux = ElementD; + using LayoutAux = LayoutD; + using ElementAmax = float; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, + 
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu
new file mode 100644
index 000000000..f5d4d6aa2
--- /dev/null
+++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu
@@ -0,0 +1,27 @@
+// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu
+
+// clang-format will break include orders
+// clang-format off
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_sm90_int8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+// clang-format on
+
+namespace fastdeploy {
+
+void cutlass_scaled_mm_azp_sm90_int8(
+    paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b,
+    paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
+    paddle::Tensor const &azp_adj,
+    paddle::optional<paddle::Tensor> const &azp,
+    paddle::optional<paddle::Tensor> const &bias) {
+  if (azp) {
+    return cutlass_scaled_mm_sm90_int8_epilogue<
+        c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
+                                         *azp, bias);
+  } else {
+    return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBiasAzp>(
+        out, a, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
+
+}  // namespace fastdeploy
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp
new file mode 100644
index 000000000..9a601f75a
--- /dev/null
+++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp
@@ -0,0 +1,34 @@
+// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+
+#include "helper.h"
+
+template <typename Fp8Func, typename Int8Func>
+void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a,
+                        paddle::Tensor const &b, paddle::Tensor const &a_scales,
+                        paddle::Tensor const &b_scales,
+                        paddle::optional<paddle::Tensor> const &bias,
+                        Fp8Func fp8_func, Int8Func int8_func) {
+  PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32);
+  PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32);
+
+  int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1];
+
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) {
+    // Standard per-tensor/per-token/per-channel scaling
+    PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+    if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) {
+      fp8_func(c, a, b, a_scales, b_scales, bias);
+    } else {
+      PD_CHECK(a.dtype() == paddle::DataType::INT8);
+      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
+        int8_func(c, a, b, a_scales, b_scales, bias);
+      } else {
+        PD_CHECK(false, "Int8 not supported for this architecture");
+      }
+    }
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "No kernel for this combination of input dtypes is implemented."));
+  }
+}
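The helper above takes the architecture-specific entry points as callables. A minimal usage sketch follows; the wrapper name is hypothetical, and the two entry points it passes are the SM90 functions declared in scaled_mm_kernels.hpp below.

// Hypothetical wrapper showing how dispatch_scaled_mm is wired up:
// fp8 inputs route to the first callable, int8 inputs to the second.
// Passing nullptr as int8_func would instead hit the
// "Int8 not supported for this architecture" check.
void example_cutlass_scaled_mm_sm90(
    paddle::Tensor &c, paddle::Tensor const &a, paddle::Tensor const &b,
    paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
    paddle::optional<paddle::Tensor> const &bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     fastdeploy::cutlass_scaled_mm_sm90_fp8,
                     fastdeploy::cutlass_scaled_mm_sm90_int8);
}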
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp
new file mode 100644
index 000000000..75472ea80
--- /dev/null
+++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp
@@ -0,0 +1,35 @@
+// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+
+#pragma once
+
+#include "helper.h"
+
+namespace fastdeploy {
+
+void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
+                                paddle::Tensor const &b,
+                                paddle::Tensor const &a_scales,
+                                paddle::Tensor const &b_scales,
+                                paddle::optional<paddle::Tensor> const &bias);
+
+void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a,
+                                 paddle::Tensor const &b,
+                                 paddle::Tensor const &a_scales,
+                                 paddle::Tensor const &b_scales,
+                                 paddle::optional<paddle::Tensor> const &bias);
+
+void cutlass_scaled_mm_azp_sm90_int8(
+    paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b,
+    paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
+    paddle::Tensor const &azp_adj,
+    paddle::optional<paddle::Tensor> const &azp,
+    paddle::optional<paddle::Tensor> const &bias);
+
+void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a,
+                                 paddle::Tensor const &b,
+                                 paddle::Tensor const &a_scales,
+                                 paddle::Tensor const &b_scales,
+                                 paddle::optional<paddle::Tensor> const &bias);
+
+}  // namespace fastdeploy
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu
new file mode 100644
index 000000000..801e90fd7
--- /dev/null
+++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu
@@ -0,0 +1,28 @@
+// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu
+
+// clang-format will break include orders
+// clang-format off
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_sm90_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+// clang-format on
+
+namespace fastdeploy {
+
+void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a,
+                                paddle::Tensor const &b,
+                                paddle::Tensor const &a_scales,
+                                paddle::Tensor const &b_scales,
+                                paddle::optional<paddle::Tensor> const &bias) {
+  PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  if (bias) {
+    PD_CHECK(bias->dtype() == out.dtype(),
+             "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm90_fp8_epilogue<c3x::ScaledEpilogue>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+}  // namespace fastdeploy
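A usage sketch of the fp8 entry point above; the tensor shapes stated in the comments are assumptions for illustration, not requirements spelled out in this patch, and the wrapper name is hypothetical.

// Illustrative call only. Assumed shapes: A is fp8 [M, K], B is fp8 [N, K]
// (the column-major operand), out is [M, N]; a_scales is per-tensor (1 element)
// or per-token (M elements), b_scales per-tensor or per-channel (N elements),
// matching the checks in dispatch_scaled_mm.
void example_fp8_scaled_mm(paddle::Tensor &out, paddle::Tensor const &a,
                           paddle::Tensor const &b,
                           paddle::Tensor const &a_scales,
                           paddle::Tensor const &b_scales) {
  // No bias supplied, so the scales-only ScaledEpilogue path is taken.
  fastdeploy::cutlass_scaled_mm_sm90_fp8(out, a, b, a_scales, b_scales,
                                         paddle::optional<paddle::Tensor>());
}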
diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
new file mode 100644
index 000000000..ac86aeba8
--- /dev/null
+++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
@@ -0,0 +1,125 @@
+// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
+
+#pragma once
+
+// clang-format will break include orders
+// clang-format off
+#include "scaled_mm.cuh"
+#include "cutlass_gemm_caller.cuh"
+// clang-format on
+
+/**
+ * This file defines Gemm kernel configurations for SM90 (fp8) based on the
+ * Gemm shape.
+ */
+
+namespace fastdeploy {
+
+using c3x::cutlass_gemm_caller;
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_default {
+  // M in (128, inf)
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_M128 {
+  // M in (64, 128]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_M64 {
+  // M in [1, 64]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _128>;
+  using ClusterShape = Shape<_1, _8, _1>;
+
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out,
+                                           paddle::Tensor const &a,
+                                           paddle::Tensor const &b,
+                                           EpilogueArgs &&...args) {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  PD_CHECK(a.dtype() == phi::DataType::FLOAT8_E4M3FN);
+  PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN);
+
+  using Cutlass3xGemmDefault =
+      typename sm90_fp8_config_default<InType, OutType,
+                                       Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
+
+  uint32_t const m = a.dims()[0];
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 64) {
+    // m in [1, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
+template