Compare commits


137 Commits

Author SHA1 Message Date
gaoziyuan
5224f6c434 support qwen3 name_mapping (#3170) 2025-08-04 16:37:40 +08:00
yinwei
bfef09dd73 update user email (#3087) 2025-07-31 14:25:31 +08:00
LokeZhou
1d46420c49 [cherry-pick][MM_PROCESS] add _extract_labels (#2879) (#2993) 2025-07-24 11:04:43 +08:00
ltd0924
fb0f284e67 [BugFix] fix prompt token ids type (#2994)
* Update serving_completion.py

* fix

* fix
2025-07-23 21:00:56 +08:00
Zero Rains
5d1788c7b5 polish code for prefill restrictions (#2992) 2025-07-23 05:12:01 -07:00
Zero Rains
abd238fc12 [Cherry-Pick][BugFix] Add prefill restrictions for chunked_prefill+VL (#2984) 2025-07-23 01:53:26 -07:00
Jiang-Jia-Jun
e5804b1d98 Revert "[LLM] fix multinode bugs (#2945)" (#2971)
This reverts commit b0f1e0eef4.
2025-07-22 21:23:48 +08:00
Sunny-bot1
8c43bc8176 [FIX 2.0.3]fix rejection sampling when topp=0 using _SAMPLING_EPS (#2966)
* fix rejection sampling when topp=0

* fix

* fix
2025-07-22 05:53:04 -07:00
ltd0924
b0f1e0eef4 [LLM] fix multinode bugs (#2945)
* [LLM] fix multinode bugs

* [LLM] fix multinode bugs

* [LLM] fix multinode bugs

* [LLM] fix ci bugs

* fix ci bugs

* fix ci bugs
2025-07-22 20:23:37 +08:00
ming1753
69be77c8c0 [Feature] support prompt repetition_penalty (#2954)
* [Feature] support prompt repetition_penalty (#2806)

* [Bug Fix] fix bug of prompt penalty (#2888)
2025-07-22 19:42:33 +08:00
gaoziyuan
535a15ab8f [Fix]Fix vl when import fastdeploy and fix rl config rank bug (#2953)
* support vl ori_vocab_size

* support trainer_degree in name_mapping

* fix

* fix import error

* fix local rank
2025-07-22 04:40:27 -07:00
sg263
580460046f merge 2.0.2 into 2.0.3 (#2917)
Co-authored-by: shige <shige@baidu.com>
2025-07-22 14:46:20 +08:00
Sunny-bot1
4dbc483713 [BugFix]Fix sample rejection (#2908) (#2949)
* fix config

* fix rejection

Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
2025-07-22 13:39:34 +08:00
gaoziyuan
4ead15822c 【Sync develop】support vl model name_mapping and ori_vocab_size (#2915)
* support vl ori_vocab_size

* support trainer_degree in name_mapping

* fix
2025-07-20 23:14:15 -07:00
Jiang-Jia-Jun
f941124402 [Feature] Support include_stop_str_in_output (#2930)
Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-21 10:58:32 +08:00
RAM
b89f083004 [Executor] Fix set capture sizes bug (#2903) 2025-07-18 10:58:05 +08:00
RAM
4d05ed596c Update GraphOptimizationBackend docs (#2899) 2025-07-17 21:41:38 +08:00
ltd0924
bc1866af58 [LLM] delete fixed slot (#2894) 2025-07-17 20:44:55 +08:00
yulangz
fe237fe92b [XPU][doc] pick xpu doc fix (#2897) 2025-07-17 20:01:40 +08:00
YUNSHEN XIE
3a480abcbb enable CI workflow for pull requests targeting release/* branches (#2886)
* enable CI workflow for pull requests targeting release/* branches

* fix
2025-07-17 16:48:13 +08:00
Yuanle Liu
335609efb6 fix rollout_model and add rl ut (#2882) 2025-07-17 13:37:54 +08:00
Jiang-Jia-Jun
3464f75f98 Update setup.py 2025-07-16 23:45:05 +08:00
Jiang-Jia-Jun
09d0073fdc [Sync Code] develop to release/2.0.3 (#2873)
* [LLM] support send batch data and aggregate data (#2860)

* [LLM] support send batch data and aggregate data

* [LLM] fix ci bugs

* [LLM] fix ci bugs

* [LLM] fix ci bugs

* [LLM] fix ci bugs

* [LLM] update

* [LLM] Update Multinode Deployment (#2830)

* [LLM] fix multinode bugs

* [LLM] update multinode deployment

* [LLM] update multinode deployment

* [LLM] update multinode deployment

* [LLM] update multinode deployment

* [LLM] update multinode deployment

* [LLM] fix ci bugs

* Update fastdeploy/engine/args_utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [LLM] update random port

* [LLM] update random port

* [LLM] fix ci bugs

* fix ci bugs

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-16 23:44:26 +08:00
Yuanle Liu
63d6e7ce06 fix and refine vl (#2866)
* refine vl config

* delete attn_sep

* fix vl accuracy
2025-07-16 05:59:28 -07:00
周周周
aa76085d1f [Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
2025-07-16 20:10:57 +08:00
sg263
42b80182e0 [Trace] add opentelemetry (#2852)
* add opentelemetry

* add opentelemetry

* add opentelemetry on dequeue

* add opentelemetry on dequeue

* add opentelemetry on dequeue

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-16 15:33:25 +08:00
Yuanle Liu
dda4a9f848 rl update (#2861) 2025-07-16 00:33:10 -07:00
yangjianfengo1
a83a3eea5f Read FLAGS_max_partition_size from an environment variable (#2854) 2025-07-16 14:14:21 +08:00
xiaoxiaohehe001
0d0340392f [Fix] Fix mm ep weight init. (#2855)
* fix_45t_mm

* Update load_weight_utils.py

* Update load_weight_utils.py
2025-07-16 12:02:39 +08:00
YuanRisheng
0253381fb9 fix config (#2858) 2025-07-16 11:40:10 +08:00
freeliuzc
2d1184aefe [Fix] fix expert_parallel bug in decoder stage (#2848) 2025-07-16 11:08:18 +08:00
yulangz
17314ee126 [XPU] Update doc and add scripts for downloading dependencies (#2845)
* [XPU] update xvllm download

* update supported models

* fix xpu model runner in huge memory with small model

* update doc
2025-07-16 11:05:56 +08:00
YuanRisheng
101ad33332 [BugFix] Fix Configs (#2849)
* fix config

* fix config
2025-07-15 19:50:36 -07:00
RAM
0fad10b35a [Executor] CUDA Graph support padding batch (#2844)
* cuda graph support padding batch

* Integrate the startup parameters for the graph optimization backend and provide support for user-defined capture sizes.

* Do not insert max_num_seqs when the user specifies a capture list

* Support set graph optimization config from YAML file

* update cuda graph ci

* fix ci bug

* fix ci bug
2025-07-15 19:49:01 -07:00
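The "Support set graph optimization config from YAML file" item above corresponds to the per-model config changes further down in this compare; with the YAML nesting restored, the added setting looks like this sketch (values copied from those config diffs):

```yaml
# Added to the benchmark model configs in this compare (nesting restored for readability)
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1
```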
Yuanle Liu
61b3997b85 refactor rl get_name_mappings_to_training (#2847)
* refactor rl get_name_mappings_to_training

* fix tp>1

* change variable name(ffn1->up_gate_proj/ffn2->down_proj)

* change variable name(linear_weight->weight/linear_bias->bias)

* add rl names mapping for vl

* fix ernie 0.3B error

* fix develop code

* fix
2025-07-15 07:31:42 -07:00
Zero Rains
e7bcbbab52 Merge vl execution path into normal execution path (#2829)
* merge vl model into gpu_model runner

Change-Id: I9f4691a3d5f135e8d72b1d58abcd15ef3aa3f2a6

* fix chinese

Change-Id: Ic7405109b984c21e076fb3b01ff6feb571d0119a

* fix the parse parameter

Change-Id: I4cd62ee87c06220af580d91e347145d4394917fe

* fix the bug in online_inference

Change-Id: Idb111bb2114e83017c4050b2a68cf039c6d3c559

* polish code

Change-Id: I7d4194102c2f1b0743b74fbd5fc284eb8ef4d17c
2025-07-15 22:20:03 +08:00
zhenwenDang
5fc659b900 [Docs] add enable_logprob parameter description (#2850)
* add enable_logprob parameter description

* add enable_logprob parameter description

* add enable_logprob parameter description

* add enable_logprob parameter description

* add enable_logprob parameter description

* add enable_logprob parameter description

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-15 19:47:45 +08:00
ophilia-lee
33db137d0b Add a YAML of default vLLM request parameters 2025-07-15 19:31:27 +08:00
lijingning
9d6a42b334 Adapt to vLLM responses without arrival_time; make the model field required for vLLM; add case number `no` to RequestFuncInput/RequestFuncOutput/SampleRequest 2025-07-15 19:31:27 +08:00
Jiang-Jia-Jun
1b712bba82 Update setup.py 2025-07-15 14:57:23 +08:00
AIbin
fd91da7b41 【Inference Optimize】Support wint2 triton kernel about triton_utils_v2 (#2842)
* update supported_models doc
2025-07-15 14:35:40 +08:00
bukejiyu
15c8c240b5 [vl] Use top_k from config.json (#2831)
2025-07-15 00:39:12 +08:00
freeliuzc
7cdd8d290d [MTP] optimize mtp infer speed (#2840)
2025-07-14 19:50:22 +08:00
YuanRisheng
4c7b8bc458 Simplify the Config code (#2770)
* simplify the code

* fix vl

* delete config

* fix

* perfect code

* fix ci

* fix xpu

* fix xpu

* fix server

* resolve conflict

* fix mtp

* resolve conflict

* fix xpu

* fix xpu

* fix vl

* fix log

* fix qwen moe

* fix qwen moe

* fix qwen moe
2025-07-14 19:50:05 +08:00
freeliuzc
2e81792d64 [fix] fix 'force-reinstall all-depe-packages in build' (#2837) 2025-07-14 16:50:54 +08:00
AIbin
b7858c22d9 【Update Docs】update supported_models doc (#2836)
* update supported_models doc
2025-07-14 16:01:34 +08:00
GoldPancake
09bbac6de0 Add DeepGEMM pre-compile tools (#2819)
This tool allows you to compile all possible kernels in advance from the model's config.json, avoiding the situation where an uncompiled kernel is encountered and JIT compilation is triggered when certain requests arrive.
2025-07-14 14:56:41 +08:00
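A minimal sketch of the pre-compilation idea described above. The config fields (hidden_size, moe_intermediate_size) and the ahead-of-time compile step are illustrative assumptions, not the tool's actual API:

```python
import itertools
import json


def enumerate_gemm_shapes(config_path: str):
    """Derive candidate (N, K) GEMM shapes from a model's config.json (field names assumed)."""
    with open(config_path) as f:
        cfg = json.load(f)
    hidden = cfg["hidden_size"]
    inter = cfg.get("moe_intermediate_size", cfg.get("intermediate_size"))
    # up/gate projection and down projection of an FFN / MoE expert
    return [(2 * inter, hidden), (hidden, inter)]


def precompile_all(config_path: str, m_values=(1, 8, 32, 128, 2048)):
    """Walk every (M, N, K) combination up front instead of JIT-compiling on first use."""
    for (n, k), m in itertools.product(enumerate_gemm_shapes(config_path), m_values):
        # A real tool would invoke DeepGEMM's JIT build here; this sketch only reports the shape.
        print(f"would pre-compile GEMM kernel for M={m}, N={n}, K={k}")
```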
freeliuzc
7f64d408a9 [MTP] support expert-parallel in mtp (#2835) 2025-07-14 14:28:50 +08:00
lddfym
ece88596ed fix spelling error (#2827) 2025-07-14 13:12:57 +08:00
bukejiyu
bad53c6b6e [vl]remove duplicated load logic (#2744)
2025-07-13 07:36:26 +08:00
xiegegege
16940822a7 add result save for ci (#2824)
LGTM
2025-07-12 23:34:46 +08:00
zhenwenDang
d48c03413f Feature/logprob bug fix (#2817)
* fix: handle missing logprobs at step 0 and incorrect finish reason with max_completion_tokens

* Prevent response_logprobs.logprob_token_ids[0] from going out of bounds
2025-07-12 16:48:51 +08:00
gaoziyuan
e9e8443ea8 fix num_blocks_local when small size model in TP2 running mode (#2792) 2025-07-12 12:50:48 +08:00
gaoziyuan
749b2e9c89 support qwen3moe name_mapping (#2820) 2025-07-12 12:05:54 +08:00
Sunny-bot1
f6ad26fc08 fix topp default value (#2814)
2025-07-11 17:10:21 +08:00
zhink
c08561c13a [Feature] support tensor-parallel-size>num_key_value_heads for qwen3 (#2799) 2025-07-11 15:09:43 +08:00
chen
2c3607407f check (#2811) 2025-07-11 13:54:52 +08:00
lddfym
b5e4288704 Global scheduler supports configuring hot updates (#2807)
* Check if the controller port is available

* Global scheduler supports configuring hot updates

* add interface: /controller/scheduler

* add interface: /controller/scheduler
2025-07-11 13:38:07 +08:00
yulangz
abbbd0cddc [XPU] Update docker file (#2809) 2025-07-11 13:26:38 +08:00
yinwei
e98937cbba delete useless file (#2772)
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-11 11:46:04 +08:00
Sunny-bot1
240d6236bc [Fix]fix top_k_top_p sampling (#2801)
* fix topk-topp

* update

* add base_non_truncated
2025-07-10 22:35:10 +08:00
littledgg
59071268b6 [Executor] Move forward_meta.py to fastdeploy/model_executor (#2774)
* Use PEP 563 in attention.py and fix conflict

* merge commit

* Change what was left out last time
2025-07-10 20:36:51 +08:00
lizexu123
8c660a0dfb [BugFix] fix RMSNorm rms_norm_esp (#2797)
* fix rms

* add vl

* fix

* add vl

* fix

* fix
2025-07-10 20:02:24 +08:00
LiqinruiG
ce5adec877 [Doc] modify offline-inference docs (#2800)
* modify offline-inference docs

* [bug] remove tool_call_content
2025-07-10 19:41:12 +08:00
Zeyu Chen
36571fd2d9 Update README.md
2025-07-10 17:01:08 +08:00
yulangz
830de5a925 [XPU] Supports TP4 deployment on 4,5,6,7 (#2794)
* Support running on cards 4,5,6,7 via XPU_VISIBLE_DEVICES
* Update the multi-card notes in the XPU docs
2025-07-10 16:48:08 +08:00
chen
d33105baeb [Feature] Online Chat API Support Return logprobs (#2777)
* online chat support logprobs

* check xpu

* check vl_gpu_model_runner and xpu_model_runner

* get_worker() check platform
2025-07-10 16:33:40 +08:00
K11OntheBoat
24f934f1f9 [BugFix] Fix low prediction accuracy of deepseekv3 (#2798) 2025-07-10 16:16:44 +08:00
Sunny-bot1
1e2319cbef Rename top_p_sampling to top_k_top_p_sampling (#2791) 2025-07-10 00:09:25 -07:00
Sunny-bot1
e45050cae3 [Feature] support top_k_top_p sampling (#2753)
* support top_k_top_p sampling

* fix

* add api param

* add api para

* fix

* fix

* fix

* fix

* fix

* fix

* fix
2025-07-09 20:58:58 -07:00
Ryan
b0f525955c [SOT] Remove breakgraph in post processing && fix datatype (#2780) 2025-07-10 11:26:00 +08:00
Yuanle Liu
2ea267f624 assert prompt len > 0 (#2773) 2025-07-10 11:14:52 +08:00
0x3878f
1d8af7ab73 Add env variable for dy2st (#2779) 2025-07-10 11:06:06 +08:00
LiqinruiG
54affdc44b [Doc] modify offline_inference docs (#2787)
* modify reasoning_output docs

* modify offline inference docs

* modify offline inference docs

* modify offline_inference docs

* modify offline_inference docs
2025-07-10 01:06:14 +08:00
Jiang-Jia-Jun
a4fdb3970b [BugFix] Fix vocab size error for ernie model (#2785)
* [BugFix] Fix vocab size error for ernie model

* [BugFix] Fix vocab size error for ernie model

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-10 01:05:51 +08:00
Jiang-Jia-Jun
2a86928657 [BugFix Revert] Fix vocab size error for ernie model 2025-07-09 22:14:54 +08:00
Jiang-Jia-Jun
b1c53fa779 [BugFix] Fix vocab size error for ernie model 2025-07-09 22:13:41 +08:00
lizexu123
da20cf681e [Bug fix] Fixed the garbled text issues in Qwen3-8B (#2783) 2025-07-09 22:03:57 +08:00
LiqinruiG
4ccd1696ab [Doc] modify offline inference docs (#2747)
* modify reasoning_output docs

* modify offline inference docs

* modify offline inference docs
2025-07-09 20:53:26 +08:00
chen
888780ffde [Feature] block_wise_fp8 support triton_moe_backend (#2767) 2025-07-09 19:22:47 +08:00
RAM
e3768c5a83 [Executor] Fix bug of logger.debug (#2778) 2025-07-09 04:13:43 -07:00
lifulll
1f28bdf994 dcu adapter ernie45t (#2756)
Co-authored-by: lifu <lifu@sugon.com>
Co-authored-by: yongqiangma <xing.wo@163.com>
2025-07-09 18:56:27 +08:00
RAM
03a74995b8 Clear dead code And supplementary notes (#2757)
* 1.supplementary notes 2.delete dead code

* fix bug of forward meta

* Global modification of forward meta

* fix vl model_runner bug
2025-07-09 16:17:34 +08:00
zhink
b89180f1cd [Feature] support custom all-reduce (#2758)
* [Feature] support custom all-reduce

* add vllm adapted
2025-07-09 16:00:27 +08:00
yulangz
be21ef5047 [XPU] Supports BF16 for ERNIE-4.5-21B-A3B and ERNIE-4.5-0.3B (#2765)
* fix no quant xpu moe

* change dir of xpu moe weight only
2025-07-09 15:57:51 +08:00
celsowm
771e71a24d Feat/blackwell sm100 support (#2670)
* Add initial support for NVIDIA Blackwell (SM100) architecture

This change introduces initial support for the NVIDIA Blackwell GPU
architecture, specifically targeting SM100 (Compute Capability 10.x)
with '100a' architecture-specific features (e.g., for CUTLASS).

Key changes:
- Updated custom_ops/setup_ops.py to generate appropriate gencode
  flags (arch=compute_100a,code=sm_100a) when '100' is specified
  in FD_BUILDING_ARCS. Requires CUDA 12.9+.
- Updated custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h:
    - Added CutlassTileConfigSM100 enum (with placeholder tile shapes).
    - Added BLACKWELL to CandidateConfigTypeParam.
    - Updated CutlassGemmConfig struct with is_sm100 flag,
      tile_config_sm100, and new constructor for SM100.
    - Modified toString() and fromString() for SM100 support.
- Updated custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu:
    - Added get_candidate_tiles_sm100() (with placeholder tiles).
    - Added placeholder mcast support functions for SM100.
    - Updated get_candidate_configs() to include SM100 paths using
      the BLACKWELL flag and new SM100 config types.
- Updated build.sh with comments to guide users on specifying '100'
  for Blackwell in FD_BUILDING_ARCS.

Further work:
- Optimal CUTLASS tile configurations for SM100 need to be researched
  and updated in cutlass_heuristic.cu.
- Kernel auto-generation scripts in custom_ops/utils/ may need
  SM100-specific versions if Blackwell's hardware features for FP8/TMA
  differ significantly from SM90.
- Compatibility of third-party libraries (CUTLASS v3.8.0, DeepGEMM)
  with Blackwell should be fully verified.

* Feat: Implement detailed Blackwell (SM100) CUTLASS heuristics

This change integrates specific, expert-provided CUTLASS heuristic
configurations for the NVIDIA Blackwell (SM100) GPU architecture,
replacing previous placeholders. This includes:

- Updated `custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h`:
    - Populated `CutlassTileConfigSM100` enum with specific tile shapes
      (e.g., CtaShape64x64x128B, CtaShape128x128x128B) suitable for SM100.
    - Added `FP4_ONLY` to `CandidateConfigTypeParam` for new FP4 paths.

- Updated `custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu`:
    - Implemented `get_candidate_tiles_sm100` with detailed logic for
      selecting tile configurations based on GROUPED_GEMM and FP4_ONLY flags,
      using the new SM100 tile enums.
    - Implemented `supports_mcast_along_m_sm100` and
      `supports_mcast_along_n_sm100` with specific tile checks for Blackwell.
    - Updated the `sm == 100` (Blackwell) block in `get_candidate_configs`
      to use these new helper functions and accurately populate candidate
      kernel configurations for various cluster shapes.

- `custom_ops/setup_ops.py` remains configured to compile for
  `arch=compute_100a,code=sm_100a` with CUDA 12.9+ for these features.

This aligns the codebase with heuristic configurations similar to those
in upstream TensorRT-LLM / CUTLASS for Blackwell, enabling more
performant kernel selection on this new architecture.

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-09 15:29:42 +08:00
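A rough sketch of the gencode-flag translation described above; the 90 → 90a and 100 → 100a mapping is stated in the commit message and in build.sh, while the function itself is illustrative rather than the actual setup_ops.py code:

```python
def gencode_flags(building_arcs):
    """Translate FD_BUILDING_ARCS entries such as [80, 90, 100] into nvcc gencode flags."""
    flags = []
    for arch in building_arcs:
        # SM90 (Hopper) and SM100 (Blackwell) need the 'a' suffix for arch-specific features;
        # SM100 additionally requires CUDA 12.9+ per the commit message.
        suffix = "a" if arch in (90, 100) else ""
        flags.append(f"-gencode=arch=compute_{arch}{suffix},code=sm_{arch}{suffix}")
    return flags


print(gencode_flags([80, 90, 100]))
# ['-gencode=arch=compute_80,code=sm_80',
#  '-gencode=arch=compute_90a,code=sm_90a',
#  '-gencode=arch=compute_100a,code=sm_100a']
```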
yulangz
0350831c2b fix xpu offline demo garbled output (#2763) 2025-07-09 14:51:20 +08:00
RichardWooSJTU
fee544e808 fix ep prefill (#2762) 2025-07-09 14:03:05 +08:00
Ryan
c4718fd693 Enable SOT D2St in Multimodal Model (#2735) 2025-07-09 12:26:18 +08:00
GoldPancake
f7cad30a38 [Feature] Add speculative decoding simulation benchmark. (#2751)
* Add speculative decoding simulation benchmark

* Fix the name of the parameter
2025-07-09 12:08:43 +08:00
gaoziyuan
6b10c19482 【Feature】add fd commit/branch info when start server (#2752)
* add_commit_config

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-09 11:52:22 +08:00
EnflameGCU
f4f1d8de44 Support for non-CUDA builds (#2750)
Co-authored-by: yongqiangma <xing.wo@163.com>
2025-07-09 11:48:40 +08:00
RichardWooSJTU
6610aa29d0 Revert "[Bug fix] fix attention rank init (#2743)" (#2761)
This reverts commit e8bbe7244b.
2025-07-09 10:38:12 +08:00
Ryan
f72c4de539 [SOT] Make custom_op dy&st unified (#2733)
* make_custom_op dy&st unified

* add instance judgement
2025-07-08 19:21:44 +08:00
xiegetest
f6ffbc3cbd add precision check for ci (#2732)
* add precision check for ci

* add precision check for ci

* add precision check for ci

* add precision check for ci

---------

Co-authored-by: xiegegege <xiege01@baidu.com>
2025-07-08 18:43:53 +08:00
RichardWooSJTU
e8bbe7244b [Bug fix] fix attention rank init (#2743)
* fix attention rank init

* fix attention rank init
2025-07-08 17:19:49 +08:00
Longzhi Wang
57b086dc6b [Bug fix] Add the missing pod_ip param to the launch_cache_manager function. (#2742)
* [Bug fix] fix the missing position args in expert_service.py

* update
2025-07-08 14:52:13 +08:00
lizexu123
525be243e7 [Bug fix] Fixed the garbled text issues in Qwen3-8B (#2737)
* fix qwen3.py

* update

* update lm_head tie_word_embeddings

* update tie_word_embeddings

* fix

* fix tie_word_embedding not in config.json

---------

Co-authored-by: lizexu <lizexu@baidu.com>
2025-07-07 23:15:27 -07:00
EnflameGCU
d0f4d6ba3a [GCU] Support gcu platform (#2702)
baseline: e7fa57ebae

Co-authored-by: yongqiangma <xing.wo@163.com>
2025-07-08 13:00:52 +08:00
gaoziyuan
26d5d737dd 【Feature】support some qwen2 functions (#2740)
* add rl qwen model support

* fix

* fix
2025-07-08 12:03:04 +08:00
Ryan
fefbd65cf8 [SOT] Remove BreakGraph with paddle.maximum (#2731)
* rm if with clip

* clip -> maximum

* int64 -> int32
2025-07-08 11:44:25 +08:00
ming1753
1eb8ea7328 [Bug fix] fix compile bug when sm < 89 (#2738) 2025-07-08 11:24:52 +08:00
ming1753
ef6649a577 [Optimize] Optimize tensorwise fp8 performance (#2729)
* [Optimize] Optimize tensorwise fp8 performance
2025-07-07 20:06:28 +08:00
liddk1121
1b54a2831e Adapt for iluvatar gpu (#2684) 2025-07-07 16:53:14 +08:00
YUNSHEN XIE
2579e8fea8 support FastDeploy version setting (#2725)
2025-07-07 14:50:11 +08:00
Yuanle Liu
91528f1af9 remove redundant install whl of fastdeploy (#2726)
* remove redundant install

* remove redundant install
2025-07-06 23:49:37 -07:00
lddfym
4e293e50fa Check if the controller port is available (#2724) 2025-07-07 13:24:55 +08:00
chen
66b321d9ec Update eb45-0.3B cuda memory (#2686) 2025-07-07 11:31:15 +08:00
ltd0924
68b4755587 [LLM] support multi node deploy (#2708)
* [LLM] support multi node deploy

* Update engine.py

* fix bugs

* fix

* [LLM] support multi node deploy

* [LLM] support multi node deploy

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-06 10:33:51 +08:00
LQX
04a8e1ef2b Update XPU CI, test=model (#2721) 2025-07-06 10:19:04 +08:00
Ting
a6e9161045 fix bug. (#2718)
2025-07-05 08:19:19 +08:00
Ting
90ef28d982 spec token map lazy. (#2715)
2025-07-05 00:14:54 +08:00
YuBaoku
b37585e693 [BugFix] fix paddle_git_commit_id error (#2714)
* set git identity to avoid merge failure in CI

* add ci cases

* [CI] Add validation for MTP and CUDAGraph

* [BugFix] fix paddle_git_commit_id error
2025-07-04 22:16:37 +08:00
lizexu123
9cb08e71e8 add support QWQ enable_thinking (#2706)
* add support QWQ enable_thinking

* add stream=True

* fix stream=true

* fix qwen

---------

Co-authored-by: lizexu <lizexu@baidu.com>
2025-07-04 20:55:23 +08:00
YuBaoku
dacc46f04c [CI] Add validation for MTP and CUDAGraph (#2710)
* set git identity to avoid merge failure in CI

* add ci cases

* [CI] Add validation for MTP and CUDAGraph
2025-07-04 18:13:54 +08:00
Jiang-Jia-Jun
09ded7715f Update mkdocs.yml 2025-07-04 17:55:52 +08:00
LQX
11cfdf5d89 Add XPU CI, test=model (#2701)
* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model

* Add XPU CI, test=model
2025-07-04 16:16:06 +08:00
GoldPancake
e7fa57ebae Extract eh_proj Layer from ParallelLMHead for MTP to Avoid Weight Transposition Issue (#2707)
* fix mtp eh_proj layer

* fix mtp update_cfg function

* fix docstring

* simplify class name
2025-07-04 14:15:04 +08:00
gaoziyuan
a5ae88ded9 [feature]add fd whl version info (#2698) 2025-07-04 14:12:42 +08:00
ltd0924
87e638498c [RL] update reschedule finish reason (#2709) 2025-07-04 13:47:36 +08:00
freeliuzc
667547be59 support chunk_prefill in MTP (#2705) 2025-07-04 11:55:48 +08:00
LiqinruiG
b38823bc66 modify reasoning_output docs (#2696) 2025-07-04 11:30:02 +08:00
Divano
050d9658a5 Update requirements.txt 2025-07-04 09:53:03 +08:00
Divano
be5cabaf80 add quick benchmark (#2703)
The test script does not need to go through CI
2025-07-04 09:32:36 +08:00
Yuanle Liu
240bdac2a4 [feat] support fa3 backend for pd disaggregated (#2695)
* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* support fa3 backend run in pd disaggregated

* delete use_fast_ffn
2025-07-03 22:33:27 +08:00
ltd0924
00863c43fd [Bug] fix logger format (#2689)
2025-07-03 19:58:03 +08:00
kevin
3d3bccdf79 [doc] update docs (#2690) 2025-07-03 19:33:19 +08:00
Jiang-Jia-Jun
9fd74f75bd Update dynamic_weight_manager.py 2025-07-03 15:55:22 +08:00
Jiang-Jia-Jun
05c670e593 [Sync] Update to latest code (#2679)
* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-03 15:43:53 +08:00
Jiang-Jia-Jun
d222248d00 Update README.md 2025-07-03 15:28:28 +08:00
Jiang-Jia-Jun
e5b94d4117 Update README.md 2025-07-03 15:28:05 +08:00
Jiang-Jia-Jun
87e2e58a22 Update gh-pages.yml 2025-07-03 15:26:21 +08:00
Jiang-Jia-Jun
de20e5a992 Update Dockerfile.xpu
2025-07-03 10:14:50 +08:00
Jiang-Jia-Jun
2f9c0618f0 Update Dockerfile.gpu 2025-07-03 10:14:39 +08:00
Yuanle Liu
9a14ab6572 add --force-reinstall --no-cache-dir when pip install fastdeploy*.whl (#2682)
2025-07-02 05:32:20 -07:00
Divano
d1cb3ed571 Update gh-pages.yml (#2680) 2025-07-02 17:36:18 +08:00
handiz
b8a8a19689 add wint2 performance (#2673) 2025-07-02 17:10:01 +08:00
351 changed files with 31221 additions and 8793 deletions


@@ -2,7 +2,9 @@ name: CI
on:
pull_request:
branches: [ develop ]
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
@@ -27,9 +29,11 @@ jobs:
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
@@ -38,7 +42,7 @@ jobs:
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME}
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}

.github/workflows/ci_xpu.yml (new file)

@@ -0,0 +1,87 @@
name: CI_XPU
on:
pull_request:
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}-xpu-ci
cancel-in-progress: true
jobs:
CI_XPU:
runs-on: [self-hosted, XPU-P800-8Card]
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
# Because the system version is lower than 2.23, the checkout cannot be used.
# - name: Checkout code
# uses: actions/checkout@v4
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}
fi
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
git merge pr/${{ github.event.pull_request.number }}
git log -n 3 --oneline
else
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
if [[ "$last_char" =~ [0-3] ]]; then
gpu_id="$last_char"
else
gpu_id="0"
fi
FD_API_PORT=$((9180 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
FD_METRICS_PORT=$((9170 + gpu_id * 100))
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
-v $(pwd):/workspace -w /workspace \
-v "/ssd3:/ssd3" \
-e "MODEL_PATH=/ssd3/model" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci_xpu.sh
"


@@ -3,8 +3,6 @@ name: Deploy GitHub Pages
on:
push:
branches: [ develop ]
pull_request:
branches: [ develop ]
permissions:
contents: write
@@ -21,4 +19,6 @@ jobs:
- name: Deploy to GitHub Pages
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: mkdocs gh-deploy --force --remote-name origin
run: |
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
mkdocs gh-deploy --force --remote-name origin

.gitignore

@@ -162,3 +162,5 @@ custom_ops/tmp*
build
.ccls-cache
third_party


@@ -5,12 +5,6 @@ default_stages:
- pre-commit # Run locally
# - manual # Run in CI
repos:
# Formatting
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
args: [--in-place, --verbose]
# Code checks
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
@@ -29,15 +23,6 @@ repos:
rev: 6.0.1
hooks:
- id: isort
# # Formatting
# - repo: https://github.com/pre-commit/mirrors-clang-format
# rev: v20.1.3
# hooks:
# - id: clang-format
# # exclude: '.*'
# types_or: [c++, cuda]
# args: [--style=file, --verbose]
# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29


@@ -8,14 +8,17 @@
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
</p>
--------------------------------------------------------------------------------


@@ -105,3 +105,30 @@ python benchmark_serving.py \
--save-result > infer_log.txt 2>&1 &
```
### Speculative decoding benchmark tool
#### Usage:
```bash
python benchmarks/benchmark_mtp.py \
--host 127.0.0.1 --port 8000 \
--max-concurrency 16 32 64 96 --num-prompts 256 \
--acceptance-rate 0.8 --draft-token-steps 1 2 3 \
--s_itl-base-model 15.88 22.84 16.47 16.93 \
--dataset-name EBChat \
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```
#### Parameter description
```bash
--host: server IP address, used to build the request URL
--port: server HTTP port, used to build the request URL
--max-concurrency: concurrency levels to test
--num-prompts: total number of requests to send
--acceptance-rate: simulated acceptance rate for speculative decoding
--draft-token-steps: number of speculative decoding draft steps
--s_itl-base-model: decode latency of the base model (obtainable with the benchmark tool above), one value per max-concurrency entry
--dataset-name: dataset class; set to "EBChat" to read a dumped FD-format dataset
--dataset-path: path to the test dataset
```


@@ -36,6 +36,7 @@ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
"""Input for requesting LLMs via API"""
no: int
prompt: str
history_QA: Optional[dict]
hyper_parameters: dict
@@ -54,6 +55,7 @@ class RequestFuncInput:
@dataclass
class RequestFuncOutput:
"""Output for requesting LLMs via API"""
no: int = 0
generated_text: str = ""
reasoning_content: str = ""
success: bool = False
@@ -84,7 +86,7 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model": "default",
"model": request_func_input.model,
"messages": request_func_input.history_QA,
"stream": True,
"stream_options": {
@@ -97,6 +99,9 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
@@ -104,6 +109,7 @@ async def async_request_eb_openai_chat_completions(
output = RequestFuncOutput()
output.prompt_len = 0
output.no = request_func_input.no
ttft = 0.0
st = time.perf_counter()
@@ -132,7 +138,8 @@ async def async_request_eb_openai_chat_completions(
ttft = timestamp - st
output.ttft = ttft
# cached_tokens
output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
# Decoding phase
else:
@@ -141,12 +148,12 @@ async def async_request_eb_openai_chat_completions(
output.generated_text += content or ""
output.reasoning_content += reason_content or ""
output.arrival_time.append(choices[0].get("arrival_time"))
elif usage := data.get("usage"):
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage", {}):
output.output_tokens = usage.get(
"completion_tokens")
"completion_tokens", 0)
output.prompt_tokens = usage.get(
"prompt_tokens")
"prompt_tokens", 0)
most_recent_timestamp = timestamp
@@ -173,6 +180,7 @@ async def async_request_eb_openai_chat_completions(
f.write(str(output) + "\n")
if pbar:
pbar.update(1)
print("#####final_output:", output)
return output
@@ -189,7 +197,7 @@ async def async_request_eb_openai_completions(
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": "default",
"model": request_func_input.model,
"prompt": request_func_input.prompt,
"stream": True,
"stream_options": {
@@ -202,14 +210,20 @@ async def async_request_eb_openai_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:", json.dumps(payload, ensure_ascii=False))
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"Content-Type": "application/json"
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
output.no = request_func_input.no
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
@@ -226,6 +240,7 @@ async def async_request_eb_openai_completions(
"data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, chunk.usage)
timestamp = time.perf_counter()
data = json.loads(chunk)
# NOTE: Some completion API might have a last
@@ -235,21 +250,22 @@ async def async_request_eb_openai_completions(
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
generated_text += text or ""
most_recent_timestamp = timestamp
output.arrival_time.append(choices[0].get("arrival_time"))
generated_text += text or ""
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage"):
output.prompt_tokens = usage.get(
"prompt_tokens")
@@ -262,8 +278,15 @@ async def async_request_eb_openai_completions(
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
if output.generated_text == "":
output.success = False
output.error = "No generated text found!"
else:
output.success = True
else:
output.error = response.reason or ""
output.success = False
@@ -271,6 +294,8 @@ async def async_request_eb_openai_completions(
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print("final_output:{}".format(output))
if pbar:
pbar.update(1)


@@ -38,7 +38,7 @@ class SampleRequest:
"""
Represents a single inference request for benchmarking.
"""
no: int
prompt: Union[str, Any]
history_QA: Union[str, Any]
json_data: Optional[dict]
@@ -229,6 +229,7 @@ class EBDataset(BenchmarkDataset):
**kwargs,
) -> list:
samples: list = []
cnt = 1
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -246,16 +247,17 @@ class EBDataset(BenchmarkDataset):
prompt, None)
samples.append(
SampleRequest(
no=cnt,
prompt=prompt,
prompt_len=self.prompt_len,
history_QA=[],
expected_output_len=new_output_len,
))
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples
class EBChatDataset(BenchmarkDataset):
"""
Implements the ShareGPT dataset. Loads data from a JSON file and generates
@@ -284,6 +286,7 @@ class EBChatDataset(BenchmarkDataset):
**kwargs,
) -> list:
samples: list = []
cnt = 1
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -297,12 +300,14 @@ class EBChatDataset(BenchmarkDataset):
prompt, None)
samples.append(
SampleRequest(
no=cnt,
json_data=json_data,
prompt=prompt,
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
))
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples

benchmarks/benchmark_mtp.py (new file)

@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union
import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark
def prepare_input_requests(
num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
dataset_mapping = {
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
}
try:
input_requests = dataset_mapping[dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {dataset_name}") from err
return input_requests
class FakeTokenizer:
def encode(self, text: str, add_special_tokens: bool = False):
return []
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
selected_percentile_metrics = ["s_itl"]
selected_percentiles = []
# Run benchmark
results = asyncio.run(
benchmark(
backend="openai-chat",
api_url=f"{base_url}/v1/chat/completions",
base_url=base_url,
model_id="default",
model_name="default",
input_requests=input_requests,
hyper_parameters={},
logprobs=None,
request_rate=float("inf"),
burstiness=1.0,
disable_tqdm=disable_tqdm,
profile=False,
selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles,
ignore_eos=False,
goodput_config_dict=None,
max_concurrency=max_concurrency,
lora_modules=None,
extra_body=None,
)
)
record = {
"mean_s_itl_ms": results["mean_s_itl_ms"],
}
return record
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
tmp = 0.0
for i in range(draft_token_step):
tmp += pow(acceptance_rate, i + 1)
r_ac = tmp / (1 + tmp)
return t_ori / ((1 - r_ac) * t_mtp)
def main(args):
base_url = f"http://{args.host}:{args.port}"
input_requests = prepare_input_requests(
args.num_prompts, args.dataset_name, args.dataset_path
)
if len(args.max_concurrency) != len(args.s_itl_base_model):
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
# Warmup
print("Starting warmup...")
with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f):
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
# Benchmark
record = send_one_batch(base_url, max_concurrency, input_requests, False)
metric_header = f"Speed up"
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
for draft_token_step in args.draft_token_steps:
speedup = calculate_speedup(
args.acceptance_rate,
draft_token_step,
s_itl,
record["mean_s_itl_ms"],
)
print(
"{:<40} {:<10.2f}".format(
f"Speed up on {draft_token_step} steps draft", speedup
)
)
print("=" * 50)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
)
parser.add_argument(
"--port",
type=str,
default="8000",
)
parser.add_argument(
"--max-concurrency",
type=int,
nargs="+",
default=(1, 2, 4, 8, 16, 32),
)
parser.add_argument(
"--num-prompts",
type=int,
default=128,
)
parser.add_argument(
"--acceptance-rate",
type=float,
default=0.8,
)
parser.add_argument(
"--draft-token-steps",
type=int,
nargs="+",
default=(1, 2),
)
parser.add_argument(
"--s_itl-base-model",
type=float,
nargs="+",
)
parser.add_argument(
"--dataset-name",
type=str,
default="EBChat",
)
parser.add_argument(
"--dataset-path",
type=str,
)
args = parser.parse_args()
main(args)
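As a quick sanity check of calculate_speedup above, assume an acceptance rate of 0.8, two draft-token steps, and equal per-step decode latencies (e.g. the 15.88 ms value from the README example):

```python
# Reproduces the arithmetic of calculate_speedup for assumed inputs.
acceptance_rate, draft_token_step = 0.8, 2
tmp = sum(acceptance_rate ** (i + 1) for i in range(draft_token_step))  # 0.8 + 0.64 = 1.44
r_ac = tmp / (1 + tmp)                                                  # ~0.5902 accepted fraction
t_ori = t_mtp = 15.88                                                   # ms per decode step
print(t_ori / ((1 - r_ac) * t_mtp))                                     # ~2.44x predicted speedup
```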


@@ -182,6 +182,7 @@ def calculate_metrics(
# len(outputs[i].itl) since multiple output tokens may be
# bundled together
# Note : this may inflate the output token count slightly
continue
actual_output_lens.append(output_len)
input_lens.append(outputs[i].prompt_len)
@@ -209,6 +210,8 @@ def calculate_metrics(
if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) /
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
else:
print("len(outputs[i].arrival_time) <= 2")
completed += 1
else:
actual_output_lens.append(0)
@@ -341,15 +344,16 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_output_len = \
test_prompt, test_output_len, test_no = \
input_requests[0].prompt, \
input_requests[0].expected_output_len
input_requests[0].expected_output_len, input_requests[0].no
test_history_QA = input_requests[0].history_QA
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
no=test_no,
prompt_len=0,
history_QA=test_history_QA,
hyper_parameters=hyper_parameters,
@@ -384,6 +388,7 @@ async def benchmark(
profile_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=test_prompt,
no=test_no,
api_url=base_url + "/start_profile",
output_len=test_output_len,
logprobs=logprobs,
@@ -422,7 +427,7 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
prompt, output_len = request.prompt, request.expected_output_len
prompt, output_len, no = request.prompt, request.expected_output_len, request.no
history_QA = request.history_QA
req_model_id, req_model_name = model_id, model_name
@@ -433,6 +438,7 @@ async def benchmark(
request_func_input = RequestFuncInput(model=req_model_id,
model_name=req_model_name,
prompt=prompt,
no=no,
prompt_len=0,
history_QA=history_QA,
hyper_parameters=hyper_parameters,
@@ -452,6 +458,7 @@ async def benchmark(
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
no=test_no,
api_url=base_url + "/stop_profile",
output_len=test_output_len,
logprobs=logprobs,
@@ -464,6 +471,8 @@ async def benchmark(
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
print("benchmark_duration:", benchmark_duration)
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
@@ -594,6 +603,155 @@ async def benchmark(
return result
def benchmark_metrics(
benchmark_duration: float,
result_file: str,
selected_percentiles: list[float],
selected_percentile_metrics: list[str],
goodput_config_dict: dict[str, float],
):
"""Benchmark metrics statisticsgenerate benchmark result"""
outputs = []
case_no_list = []
with open(result_file) as f:
for line in f.readlines():
if "RequestFuncOutput" in line:
start = line.find("RequestFuncOutput")
end = line.rfind(")")
para_str = line[start:end + 1]
output = eval(para_str)
outputs.append(output)
input_requests = [[]] * len(outputs)
goodput_config_dict = check_goodput_args(args)
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"input_texts": ["" for input in input_requests],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
# E.g., "TTFT"
metric_name: str,
# E.g., "Time to First Token"
metric_header: str,
):
# This function prints and adds statistics of the specified
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length(
# E.g., "ttft"
metric_attribute_name: str,
# E.g., "TTFT"
metric_name: str,
# E.g., "Time to First Token"
metric_header: str,
):
# This function prints and adds statistics of the specified
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}")))
result[f"mean_{metric_attribute_name}"] = getattr(
metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(
metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(
metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
value))
result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency")
process_one_length("input_len", "Input Length", "Input Length")
process_one_length("s_input_len", "Input Length", "Infer Input Length")
process_one_length("output_len", "Output Length", "Output Length")
print("=" * 50)
return result
def check_goodput_args(args):
"""Check whether the given argument has valid goodput configuration or not"""
# Check and parse goodput arguments
@@ -759,6 +917,16 @@ def main(args: argparse.Namespace):
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
# benchmark_result = benchmark_metrics(
# benchmark_duration=3600,
# result_file="your result file",
# selected_percentile_metrics=args.percentile_metrics.split(","),
# selected_percentiles=[
# float(p) for p in args.metric_percentiles.split(",")
# ],
# goodput_config_dict=goodput_config_dict,
# )
# Save config and results to json
if args.save_result:

File diff suppressed because it is too large


@@ -3,3 +3,4 @@ tqdm
numpy
Pillow
pyyaml
requests


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 96
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 4
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wfp8afp8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1


@@ -0,0 +1,3 @@
metadata:
min_tokens: 32
max_tokens: 33


@@ -0,0 +1,11 @@
top_p: 1.0
temperature: 1.0
metadata:
min_tokens: 1
max_tokens: 30721
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0
skip_special_tokens: false
chat_template_kwargs:
enable_thinking: true


@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
PYTHON_VERSION=${2:-"python"}
export python=$PYTHON_VERSION
FD_CPU_USE_BF16=${3:-"false"}
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
# These will be translated to 90a / 100a in setup_ops.py for specific features.
FD_BUILDING_ARCS=${4:-""}
@@ -74,8 +77,10 @@ function copy_ops(){
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
DEVICE_TYPE="rocm"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "ROCM ops have been copy to fastdeploy"
echo -e "BASE and ROCM ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
@@ -104,6 +109,23 @@ function copy_ops(){
return
fi
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi
is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
if [ "$is_gcu" = "True" ]; then
DEVICE_TYPE="gcu"
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
echo -e "gcu ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
@@ -163,17 +185,24 @@ function build_and_install() {
exit 1
fi
echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
}
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
cd $DIST_DIR
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
if [ $? -ne 0 ]; then
cd ..
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
exit 1
function version_info() {
output_file="fastdeploy/version.txt"
fastdeploy_git_commit_id=$(git rev-parse HEAD)
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
cuda_version="nvcc-not-installed"
if command -v nvcc &> /dev/null; then
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
fi
echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
cd ..
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
echo "Paddle version: $paddle_version" >> $output_file
echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
echo "CUDA version: $cuda_version" >> $output_file
echo "CXX compiler version: $cxx_version" >> $output_file
}
function cleanup() {
@@ -207,6 +236,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
set -e
init
version_info
build_and_install_ops
build_and_install
cleanup
@@ -237,6 +267,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
else
init
build_and_install_ops
version_info
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
fi
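The build.sh hunks above move wheel installation out of build_and_install() and add a version_info() step, called in both build paths, that records the FastDeploy commit, Paddle version and commit, CUDA (nvcc) version, and g++ version into fastdeploy/version.txt. A rough Python equivalent of what gets collected is sketched below; it assumes paddle is importable and that git, nvcc, or g++ may be absent from PATH.

# Sketch: gather the same fields version_info() writes to fastdeploy/version.txt.
# Assumes paddle is importable; missing tools are reported as "not-installed".
import re
import shutil
import subprocess

def run(cmd):
    return subprocess.run(cmd, capture_output=True, text=True).stdout

info = {}
if shutil.which("git"):
    info["fastdeploy GIT COMMIT ID"] = run(["git", "rev-parse", "HEAD"]).strip()

try:
    import paddle
    info["Paddle version"] = paddle.__version__
    info["Paddle GIT COMMIT ID"] = paddle.__git_commit__
except ImportError:
    info["Paddle version"] = "paddle-not-installed"

if shutil.which("nvcc"):
    m = re.search(r"release ([\d.]+), V", run(["nvcc", "-V"]))
    info["CUDA version"] = m.group(1) if m else "unknown"
else:
    info["CUDA version"] = "nvcc-not-installed"

if shutil.which("g++"):
    m = re.search(r"\) ([\d.]+)", run(["g++", "--version"]).splitlines()[0])
    info["CXX compiler version"] = m.group(1) if m else "unknown"

for k, v in info.items():
    print(f"{k}: {v}")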

View File

@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle
from . import jit
from .jit_kernels import (
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
index fcb377e..db9d6f3 100644
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
import subprocess
-from torch.utils.cpp_extension import CUDA_HOME
+from ..paddle_utils import CUDA_HOME
def run_cuobjdump(file_path):
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 66c370a..4761426 100644
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
-import torch
+import paddle
from typing import Optional
from .template import map_ctype
@@ -35,7 +35,7 @@ class Runtime:
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
-import torch
+import paddle
from typing import Any, Dict, Iterable, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
}
@@ -27,25 +27,25 @@ genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
- if value.dtype == torch.int:
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
+import paddle
from functools import lru_cache
from typing import Tuple
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor) -> None:
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
- this function will do a transposing with a set of slow PyTorch operations.
+ this function will do a transposing with a set of slow paddle operations.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
n, k_ = rhs.shape
m_, n_ = out.shape
- assert n % 64 == 0 and k % 128 == 0
+ # assert n % 64 == 0 and k % 128 == 0
# Type and shape checks
- assert m == m_ and n == n_ and k == k_
- assert n > 0 and k > 0
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
+ # assert out.dtype == paddle.bfloat16
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
-import torch
+import paddle
from typing import Tuple
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
"""
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow Pypaddle operations.
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
`get_m_alignment_for_contiguous_layout()` (128).
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
Values of `m_indices` in every-m-alignment-block must also be the same.
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
m__ = m_indices.numel()
# Type and shape checks
- assert m == m_ == m__ and k == k_ and n == n_
- assert lhs_scales.shape == (m, (k + 127) // 128)
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
+ # assert m_indices.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and m_indices.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Do nothing if `m` is zero
if m == 0:
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
)
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
runtime(*args)
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
- rhs: Tuple[torch.Tensor, torch.Tensor],
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
+ this function will do a transposing with a set of slow paddle operations.
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
should be separately transposed.
Arguments:
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
num_groups___ = masked_m.numel()
# Type and shape checks
- assert num_groups == num_groups_ == num_groups__ == num_groups___
- assert m == m_ and n == n_ and k == k_
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
+ # assert masked_m.dtype == paddle.int32
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
+ # assert out.is_contiguous() and masked_m.is_contiguous()
# LHS scales must be transposed for TMA load, but not for RHS scales
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
- assert rhs_scales.is_contiguous()
+ # assert rhs_scales.is_contiguous()
# Auto-tuning with compilation
global includes, template
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
- torch.cuda.current_stream(), num_sms, smem_config[0])
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
-import torch
+import paddle
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
@@ -51,10 +51,10 @@ class JITTuner:
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
@@ -1,4 +1,4 @@
-import torch
+import paddle
_num_sms = None
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
_num_sms = num_sms
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
"""
global _num_sms
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
return _num_sms
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
return ceil_div(x, alignment) * alignment
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
"""
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+ aligned_x = paddle.transpose(
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
-import torch.distributed as dist
+import paddle
+import paddle.distributed as dist
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
- torch.cuda.synchronize()
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
+ paddle.device.cuda.synchronize()
+ paddle.device.synchronize()
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
cache.zero_()
# Warmup
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
x @ y
# Testing
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
end_event.record()
- torch.cuda.synchronize()
+ paddle.device.synchronize()
return start_event.elapsed_time(end_event) / num_tests
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
fn()
if not using_nsys:
--
2.43.0
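The patch above ports DeepGEMM's Python side from torch to paddle: torch.cuda.synchronize becomes paddle.device.synchronize, torch.cuda.get_device_properties(device='cuda').multi_processor_count becomes paddle.device.cuda.get_device_properties().multi_processor_count, Tensor.stride(i) becomes Tensor.strides[i], and the L2-flush buffer becomes a paddle.int32 tensor. A small sketch of those paddle idioms, assuming a CUDA-enabled paddle build, is below; it times with wall-clock instead of CUDA events, so treat it as an illustration rather than the patched bench() itself.

# Sketch of the paddle idioms the patch swaps in (assumes a CUDA-enabled paddle build).
import time
import paddle

# SM count, as used by the num_sms helpers after the port.
num_sms = paddle.device.cuda.get_device_properties().multi_processor_count

def bench(fn, num_tests=10):
    # Flush L2 with a large int32 buffer, mirroring the patched bench().
    cache = paddle.empty([int(256e6 // 4)], dtype=paddle.int32)
    cache.zero_()
    fn()  # warmup
    paddle.device.synchronize()
    start = time.perf_counter()
    for _ in range(num_tests):
        fn()
    paddle.device.synchronize()
    return (time.perf_counter() - start) / num_tests

x = paddle.randn([1024, 1024], dtype=paddle.float32)
print("matmul avg seconds:", bench(lambda: x @ x))
print("row stride:", x.strides[0])  # Tensor.strides replaces Tensor.stride(i)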

View File

@@ -47,7 +47,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -166,7 +166,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
lambda_batch_ids,
lambda_tile_ids_per_batch,
@@ -203,7 +203,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_encoder,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
@@ -275,7 +275,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -298,7 +298,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -347,7 +347,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -404,7 +404,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = seq_lens_this_time.dims()[0];
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
@@ -474,7 +474,7 @@ std::vector<paddle::Tensor> AppendAttention(
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -551,7 +551,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -611,7 +611,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -689,7 +689,7 @@ PD_BUILD_STATIC_OP(append_attention)
"seq_lens_decoder",
"seq_lens_this_time",
"padding_offsets",
"cum_offsets",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",

View File

@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -114,8 +114,7 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -405,7 +404,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -477,8 +476,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -776,7 +774,7 @@ void MultiQueryAppendAttention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -882,7 +880,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -939,7 +937,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -974,7 +972,7 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1103,7 +1101,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1171,7 +1169,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1207,7 +1205,7 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1290,7 +1288,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1353,7 +1351,7 @@ void CascadeAppendAttentionC16Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
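Inside the kernels the query start index changes from batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]) to a direct read of cu_seqlens_q[batch_id]. If cum_offsets[b] is read as the cumulative padding ahead of sequence b (an assumption inferred from the old formula), the two expressions are equal, so the refactor only removes an indirection; the NumPy check below illustrates that.

# Sketch: old and new q_start_seq_id formulas agree when cum_offsets[b] is the
# cumulative padding of the sequences before b (assumption inferred from the old code).
import numpy as np

max_seq_len = 16
seq_lens_q = np.array([7, 1, 1, 4])
cu_seqlens_q = np.concatenate(([0], np.cumsum(seq_lens_q)))
cum_offsets = np.concatenate(([0], np.cumsum(max_seq_len - seq_lens_q)))[:-1]

for b in range(len(seq_lens_q)):
    old = b * max_seq_len - cum_offsets[b]
    new = cu_seqlens_q[b]
    assert old == new, (b, old, new)
print("old and new start indices match for every batch id")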

View File

@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -144,8 +144,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t kv_b_stride = HEAD_DIM / 2;
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -504,7 +503,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -601,8 +600,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t kv_b_stride = HEAD_DIM / 2;
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -963,7 +961,7 @@ void MultiQueryAppendC4Attention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1088,7 +1086,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1151,7 +1149,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1186,7 +1184,7 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1333,7 +1331,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1409,7 +1407,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1444,7 +1442,7 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1527,7 +1525,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1594,7 +1592,7 @@ void CascadeAppendAttentionC4Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
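The C4 variant carries the same cum_offsets to cu_seqlens_q substitution, and its strides (kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2, kv_b_stride = HEAD_DIM / 2, kv_d_stride = BLOCK_SIZE / 2) reflect a 4-bit KV cache in which two values share one byte. The NumPy sketch below only illustrates that byte-count arithmetic; scales, zero points, and the kernel's real tiling are omitted.

# Sketch: two 4-bit cache codes packed per byte, so a head of HEAD_DIM elements
# occupies HEAD_DIM / 2 bytes (matching kv_b_stride = HEAD_DIM / 2 above).
# Scale/zero-point handling and the actual layout are intentionally omitted.
import numpy as np

HEAD_DIM = 128
vals = np.random.randint(0, 16, size=HEAD_DIM, dtype=np.uint8)   # 4-bit codes
packed = (vals[0::2] | (vals[1::2] << 4)).astype(np.uint8)        # two codes per byte
assert packed.nbytes == HEAD_DIM // 2

lo = packed & 0x0F
hi = packed >> 4
unpacked = np.empty(HEAD_DIM, dtype=np.uint8)
unpacked[0::2], unpacked[1::2] = lo, hi
assert np.array_equal(unpacked, vals)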

View File

@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -151,8 +151,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t kv_d_stride = BLOCK_SIZE;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -473,7 +472,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -575,8 +574,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t kv_d_stride = BLOCK_SIZE;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -900,7 +898,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1054,7 +1052,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1111,7 +1109,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1146,7 +1144,7 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1317,7 +1315,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1387,7 +1385,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1417,7 +1415,7 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1500,7 +1498,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1565,7 +1563,7 @@ void CascadeAppendAttentionC8Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
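The C8 path gets the same signature change; here kv_b_stride stays at the full HEAD_DIM because an 8-bit cache stores one byte per element. A minimal NumPy sketch of a symmetric int8 round trip for one KV head is below; the hunk does not show the kernel's actual scale handling, so treat the scheme as illustrative only.

# Sketch: symmetric int8 quantization of one KV head (illustrative scheme only;
# one byte per element, so kv_b_stride remains HEAD_DIM in the C8 kernels).
import numpy as np

HEAD_DIM = 128
k = np.random.randn(HEAD_DIM).astype(np.float32)

scale = np.abs(k).max() / 127.0
k_int8 = np.clip(np.round(k / scale), -127, 127).astype(np.int8)
k_deq = k_int8.astype(np.float32) * scale

assert k_int8.nbytes == HEAD_DIM        # one byte per element
print("max abs dequant error:", np.abs(k - k_deq).max())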

View File

@@ -2111,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
@@ -2127,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[bid];

View File

@@ -41,7 +41,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -131,7 +131,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -176,7 +176,7 @@ void CascadeAppendAttentionKernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -212,7 +212,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -247,7 +247,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -282,7 +282,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -317,7 +317,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
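This header updates the declarations of the C16/C8/C4 kernels and of CascadeAppendAttentionKernel, which, as the four call sites above suggest, forwards a shared argument list to one specialization. Reading the suffixes as 16-, 8-, and 4-bit KV-cache paths is an assumption based on the names and the halved strides in the C4 file; the toy Python dispatcher below only conveys that fan-out idea, and its key strings and handler names are hypothetical.

# Toy sketch of dispatch by KV-cache precision; key strings and handler names are
# hypothetical, only the single-entry-point fan-out mirrors the declarations above.
def attn_c16(**kw): return "16-bit cache path"
def attn_c8(**kw):  return "8-bit cache path"
def attn_c4(**kw):  return "4-bit cache path"

_DISPATCH = {"none": attn_c16, "int8": attn_c8, "int4": attn_c4}

def cascade_append_attention(cache_quant_type: str, **kw):
    try:
        return _DISPATCH[cache_quant_type](**kw)
    except KeyError:
        raise ValueError(f"unsupported cache_quant_type: {cache_quant_type}")

print(cascade_append_attention("int8"))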

View File

@@ -0,0 +1,236 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "multi_head_latent_attention_kernel.h"
template <size_t vec_size, typename T>
struct softmax_state_t {
AlignedVector<T, vec_size> o;
T m;
T d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = __float2half(-5e4f);
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = __float2bfloat16(-3.38953e38f);
}
}
__device__ __forceinline__ softmax_state_t() {
init();
}
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
T other_m,
T other_d) {
// using kType = typename cascade_attn_nv_type2_traits<T>::type;
T m_prev = m, d_prev = d;
m = m_prev > other_m ? m_prev : other_m;
T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1 + other_o[i] * scale2;
}
}
__device__ __forceinline__ void normalize() {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d;
}
}
};
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
uint32_t num_tiles_ = num_tiles;
AlignedVector<T, vec_size> o[num_tiles];
float m;
float d;
__device__ __forceinline__ void init() {
#pragma unroll
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
}
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
__device__ __forceinline__ softmax_state_ts() {
init();
}
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
#pragma unroll
for (size_t i = 0; i < vec_size; i++) {
o[tile_id][i] /= d;
}
}
};
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
CacheT *kv_base_gptr,
const int * block_table_smem,
const uint32_t seq_offset_gmem,
const uint32_t seq_offset_smem,
const uint32_t kv_head_idx,
const uint32_t kv_num_heads,
const uint32_t tidx,
const uint32_t chunk_start,
const uint32_t chunk_end) {
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
if (block_id < 0) {
block_id = 0;
}
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
// 8/16 T/int8 each time
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
seq_offset_gmem < chunk_end
);
}
}
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
const CacheT* k_smem,
const uint32_t kv_idx_base,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t tidx,
const uint32_t gid,
const float scale,
float *s,
softmax_state_ts<vec_size, T, num_tile_v>& st) {
const CacheT* smem;
AlignedVector<T, vec_size> q_vec;
AlignedVector<T, vec_size> k_vec;
float m_prev = st.m;
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
if (iter_base + j < iter_bound) {
if constexpr (std::is_same<T, half>::value) {
s[j] = 0.f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = 0.f;
}
#pragma unroll
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
for (uint32_t i = 0; i < vec_size; ++i) {
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
}
}
#pragma unroll
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
}
__syncthreads();
} else {
if constexpr (std::is_same<T, half>::value) {
s[j] = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
s[j] = -3.38953e38f;
}
}
st.m = st.m > s[j] ? st.m : s[j];
}
// T o_scale = hexp(m_prev - st.m);
float o_scale = __expf(m_prev - st.m);
st.d *= o_scale;
#pragma unroll
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
// s[j] = hexp(s[j] - st.m);
s[j] = __expf(s[j] - st.m);
st.d += s[j];
}
#pragma unroll
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
for (uint32_t i = 0; i < vec_size; ++i) {
st.o[tile_id][i] *= o_scale;
}
}
}
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
const CacheT *base_v_smem,
const uint32_t stage_idx,
const uint32_t iter_base,
const uint32_t iter_bound,
const uint32_t tidx,
softmax_state_ts<vec_size, T, num_tile>& st) {
const CacheT* v_smem;
AlignedVector<T, vec_size> v_vec;
#pragma unroll
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
uint32_t tile_id = vid / bdx;
#pragma unroll
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
}
}
}
}
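softmax_state_t above keeps a running (o, m, d) triple for online softmax: merge() rescales both partial accumulators to a shared maximum before adding, and normalize() divides by the accumulated denominator. The NumPy sketch below mirrors that merge for a single head and checks it against attention computed over the full key range in one pass.

# Sketch: online-softmax state merge, mirroring softmax_state_t::merge()/normalize().
import numpy as np

def partial_state(q, K, V):
    s = K @ q                      # scores for this key chunk
    m = s.max()
    p = np.exp(s - m)
    return p @ V, m, p.sum()       # (o, m, d); o is left unnormalized

def merge(o1, m1, d1, o2, m2, d2):
    m = max(m1, m2)
    s1, s2 = np.exp(m1 - m), np.exp(m2 - m)
    return o1 * s1 + o2 * s2, m, d1 * s1 + d2 * s2

rng = np.random.default_rng(0)
head_dim, kv_len = 64, 96
q = rng.standard_normal(head_dim)
K = rng.standard_normal((kv_len, head_dim))
V = rng.standard_normal((kv_len, head_dim))

o1, m1, d1 = partial_state(q, K[:40], V[:40])
o2, m2, d2 = partial_state(q, K[40:], V[40:])
o, m, d = merge(o1, m1, d1, o2, m2, d2)
merged = o / d                                  # normalize()

s = K @ q
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(merged, ref)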

View File

@@ -0,0 +1,560 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decode_attention_func.cuh"
#define CHECK(call) \
do \
{ \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) \
{ \
printf("CUDA Error:\n"); \
printf(" File: %s\n", __FILE__); \
printf(" Line %d:\n", __LINE__); \
printf(" Error code:%d\n", error_code); \
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
exit(1); \
} \
}while(0)
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cu_seqlens_q,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
const float in_scale,
const int num_chunks,
const int chunk_size,
const int max_seq_len,
const int num_heads,
const int head_dim) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int qid = blockIdx.x, hid = blockIdx.y;
const int seq_len_q = seq_lens_q[qid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[qid];
if (seq_len_kv == 0) return;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
return;
}
__shared__ T smem[bdy * HEAD_DIM];
__shared__ T md_smem[bdy * 2];
const int start_token_ids = cu_seqlens_q[qid];
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
T m;
T d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = __float2half(-5e4f);
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = __float2bfloat16(-3.38953e38f);
}
// merge per ty
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
T m_prev = m;
T d_prev = d;
const T m_now = multi_m[offset];
const T d_now = multi_d[offset];
m = m_prev > m_now ? m_prev : m_now;
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
}
}
// store ty res
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
// merge bdy
softmax_state_t<vec_size, T> st{};
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
#pragma unroll
for (int i = 0; i < iter_num; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
AlignedVector<OutT, vec_size> out_vec;
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
out_vec[i] = static_cast<OutT>(st.o[i]);
}
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
}
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
CacheT * __restrict__ cache_v,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cu_seqlens_q,
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
const float scale,
const float in_scale,
const uint32_t chunk_size,
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
OutT * __restrict__ out) {
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t bid = bidx, gid = threadIdx.y;
const uint32_t tidx = threadIdx.x;
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
const uint32_t kv_num_heads = gridDim.z;
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
const int *block_table_now = block_table + bid * max_block_num_per_seq;
const uint32_t num_chunks = gridDim.y;
const uint32_t chunk_id = blockIdx.y;
const uint32_t q_len = seq_lens_q[bid];
if (q_len <= 0) {
return;
}
uint32_t kv_len = seq_lens_kv[bid]; // cached KV length only; the current query tokens are added below
if (kv_len <= 0) {
return;
}
kv_len += q_len;
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
const uint32_t q_start_idx = cu_seqlens_q[bid];
const uint32_t q_write_idx = cu_seqlens_q[bid];
if (chunk_id >= num_chunk_this_seq) {
return;
}
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
const uint32_t chunk_len = chunk_end - chunk_start;
extern __shared__ uint8_t smem[];
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
#pragma unroll
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
}
__syncthreads();
using VecT = AlignedVector<T, VEC_SIZE>;
VecT q_vec;
#pragma unroll
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
q_vec[i] *= scale;
}
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
}
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
uint32_t stage_idx = 0;
constexpr int loop_times = DEAL_EACH_TIME / bdy;
#pragma unroll
for (int i = 0; i < NUM_STAGES; ++i) {
#pragma unroll
for (int j = 0; j < loop_times; ++j) {
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
const uint32_t k_seq_id = chunk_start + k_seq_offset;
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
kv_smem,
cache_k,
block_table_now,
k_seq_id,
k_seq_offset,
kv_head_idx,
kv_num_heads,
tidx,
chunk_start,
chunk_end
);
}
commit_group();
stage_idx = (stage_idx + 1) % NUM_STAGES;
}
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
float s[DEAL_EACH_TIME];
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
for (int iter = 0; iter < num_iters; ++iter) {
wait_group<NUM_STAGES - 1>();
__syncthreads();
// compute qk
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
cu_q_smem,
kv_smem,
chunk_start + iter * DEAL_EACH_TIME,
stage_idx,
iter * DEAL_EACH_TIME,
chunk_len,
tidx,
gid,
scale,
s,
st
);
__syncthreads();
// compute sv
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
s,
kv_smem,
stage_idx,
iter * DEAL_EACH_TIME,
chunk_len,
tidx,
st
);
__syncthreads();
#pragma unroll
for (int j = 0; j < loop_times; ++j) {
const uint32_t k_seq_offset = j * bdy + gid;
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
kv_smem,
cache_k,
block_table_now,
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
stage_idx * DEAL_EACH_TIME + k_seq_offset,
kv_head_idx,
kv_num_heads,
tidx,
chunk_start,
chunk_end
);
}
commit_group();
stage_idx = (stage_idx + 1) % NUM_STAGES;
}
wait_group<0>();
__syncthreads();
// normalize when not partitioned across KV chunks (or only one chunk)
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
const uint32_t tile_id = vid / bdx;
if (!partition_kv || num_chunk_this_seq == 1) {
st.normalize(tile_id);
}
if (partition_kv && num_chunk_this_seq > 1) {
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
tmp_m[head_idx] = st.m;
tmp_d[head_idx] = st.d;
} else {
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
}
}
}
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
void MultiQueryDecoderAttention(
const AppendAttnMetaData& meta_data,
cudaStream_t &stream,
const paddle::Tensor &q,
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const int max_seq_len,
const int max_dec_len,
const float rope_scale,
const float rope_theta,
const float softmax_scale,
const float in_scale,
paddle::Tensor *out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto token_num = meta_data.token_nums;
auto bsz = meta_data.batch_size;
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
constexpr int num_stages = NUM_STAGE;
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
constexpr int blocky = GROUP_SIZE;
const int gridx = bsz;
constexpr int num_threads = blockx * blocky;
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
uint32_t cache_smem_bytes = 0;
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
const uint32_t chunk_size = get_max_partition_size(bsz);
const int num_chunks = div_up(max_dec_len, chunk_size);
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
assert(act_blocks_per_sm > 1);
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
dim3 grids(gridx, num_chunks, kv_num_heads);
dim3 blocks(blockx, blocky);
if (num_chunks <= 1) {
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
cache_vec_size, blockx, blocky>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
softmax_scale,
in_scale,
chunk_size,
nullptr,
nullptr,
nullptr,
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
);
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
} else {
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
tmp_workspace = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
tmp_m = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads));
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
softmax_scale,
in_scale,
chunk_size,
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
);
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
constexpr int mblockx = HEAD_DIM_V / vec_size;
constexpr int bdy = 256 / mblockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(mblockx, bdy);
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cu_seqlens_q.data<int>(),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
in_scale,
num_chunks,
chunk_size,
max_seq_len,
num_heads,
HEAD_DIM_V
);
}
// CHECK(cudaGetLastError());
// CHECK(cudaDeviceSynchronize());
}
template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim_qk = meta_data.head_dims;
const auto head_dim_v = meta_data.head_dims_v;
const float rope_scale = 0.0;
const float rope_theta = 0.0;
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
const uint32_t num_stage = get_cascade_attention_num_stages();
const uint32_t num_threads = get_cascade_attention_num_threads();
DISPATCH_CAUSAL(causal, CAUSAL,
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);
template void DecodeMLAAttentionKernel<paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);
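
Note for readers following the split-KV branch above: tmp_workspace / tmp_m / tmp_d hold, per chunk, the partial attention output together with its softmax max and denominator, and merge_varlen_multi_chunks_v2_kernel combines them across chunks. Below is a minimal host-side sketch of that merge rule, assuming each chunk's output is already normalized within the chunk; the struct, names, and flat head-dim vector are illustrative only and are not the kernel's actual layout.

// Minimal host-side sketch of the split-KV merge performed on the GPU above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct ChunkResult {
  std::vector<float> out;  // per-chunk attention output, normalized within the chunk
  float m;                 // max attention logit seen in the chunk
  float d;                 // sum of exp(logit - m) over the chunk
};

ChunkResult merge_chunks(const ChunkResult& a, const ChunkResult& b) {
  ChunkResult r;
  r.m = std::max(a.m, b.m);
  const float wa = a.d * std::exp(a.m - r.m);  // rescale each denominator to the new max
  const float wb = b.d * std::exp(b.m - r.m);
  r.d = wa + wb;
  r.out.resize(a.out.size());
  for (size_t i = 0; i < a.out.size(); ++i) {
    // Convex combination weighted by the rescaled denominators.
    r.out[i] = (wa * a.out[i] + wb * b.out[i]) / r.d;
  }
  return r;
}

int main() {
  ChunkResult c0{{0.25f, 0.75f}, /*m=*/1.0f, /*d=*/3.0f};
  ChunkResult c1{{0.50f, 0.50f}, /*m=*/2.0f, /*d=*/2.0f};
  const ChunkResult merged = merge_chunks(c0, c1);
  std::printf("merged: %.4f %.4f (m=%.2f, d=%.4f)\n",
              merged.out[0], merged.out[1], merged.m, merged.d);
  return 0;
}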

View File

@@ -29,7 +29,7 @@ __global__ void append_decode_cache_T_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -135,7 +135,7 @@ __global__ void append_decode_cache_T_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -255,7 +255,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -367,7 +367,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -499,7 +499,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -746,7 +746,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1048,7 +1048,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1347,7 +1347,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1740,7 +1740,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2035,7 +2035,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2363,7 +2363,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2733,7 +2733,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
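
The hunks above all make the same substitution: the decode start offset that used to be derived as bid * max_seq_len - cum_offsets[bid] is now read directly from cu_seqlens_q[bid]. Assuming cum_offsets[b] counts the padding slots before batch b in the padded bsz * max_seq_len layout, both expressions equal the exclusive prefix sum of the real query-token counts. The toy check below illustrates the identity; the per-batch counts are made up.

// Toy check of the identity behind this diff:
//   b * max_seq_len - cum_offsets[b] == cu_seqlens_q[b]
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int max_seq_len = 8;
  const std::vector<int> seq_lens_q = {3, 1, 5};  // illustrative per-batch token counts

  std::vector<int> cum_offsets(seq_lens_q.size(), 0);      // padding slots before batch b
  std::vector<int> cu_seqlens_q(seq_lens_q.size() + 1, 0); // prefix sum of real tokens
  for (size_t b = 1; b <= seq_lens_q.size(); ++b) {
    cu_seqlens_q[b] = cu_seqlens_q[b - 1] + seq_lens_q[b - 1];
    if (b < seq_lens_q.size()) {
      cum_offsets[b] = cum_offsets[b - 1] + (max_seq_len - seq_lens_q[b - 1]);
    }
  }
  for (size_t b = 0; b < seq_lens_q.size(); ++b) {
    const int old_start = static_cast<int>(b) * max_seq_len - cum_offsets[b];
    assert(old_start == cu_seqlens_q[b]);
    std::printf("batch %zu: start_token_idx = %d\n", b, old_start);
  }
  return 0;
}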

View File

@@ -22,7 +22,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -58,7 +58,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -80,7 +80,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -103,7 +103,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -126,7 +126,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -150,7 +150,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -183,7 +183,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -208,7 +208,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -233,7 +233,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -258,7 +258,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -283,7 +283,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -318,7 +318,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -345,7 +345,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -372,7 +372,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -399,7 +399,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -425,7 +425,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -472,7 +472,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -504,7 +504,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -537,7 +537,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -571,7 +571,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -604,7 +604,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -651,7 +651,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -678,7 +678,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -704,7 +704,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -730,7 +730,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -24,7 +24,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -879,7 +879,7 @@ __global__ void append_write_cache_kv_c8_qkv(
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -911,8 +911,7 @@ __global__ void append_write_cache_kv_c8_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx =
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1119,7 +1118,7 @@ __global__ void append_write_cache_kv_c4_qkv(
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -1148,8 +1147,7 @@ __global__ void append_write_cache_kv_c4_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx =
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1750,7 +1748,7 @@ void CascadeAppendWriteCacheKVC8QKV(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1815,7 +1813,7 @@ void CascadeAppendWriteCacheKVC8QKV(
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
@@ -1838,7 +1836,7 @@ void CascadeAppendWriteCacheKVC4QKV(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1885,7 +1883,7 @@ void CascadeAppendWriteCacheKVC4QKV(
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,

View File

@@ -26,7 +26,7 @@ void EncoderWriteCacheWithRopeKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -143,7 +143,7 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
batch_ids,
tile_ids,
@@ -170,7 +170,7 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
batch_ids,
tile_ids,

View File

@@ -422,7 +422,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids,
@@ -450,7 +449,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const int token_num = qkv_dims[0];
const int max_blocks_per_seq = block_tables.dims()[1];
const int block_size = key_cache.dims()[2];
const int batch_size = cum_offsets.dims()[0];
const int batch_size = seq_lens_this_time.dims()[0];
const int kv_num_heads = key_cache_dims[1];
const int head_dim = key_cache_dims[3];
const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
@@ -463,7 +462,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data.q_num_heads = num_heads;
meta_data.max_blocks_per_seq = max_blocks_per_seq;
meta_data.block_size = block_size;
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = seq_lens_this_time.dims()[0];
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));
@@ -528,7 +527,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids,
@@ -595,7 +594,6 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
"seq_lens_encoder",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"block_tables",
"kv_batch_ids",
"kv_tile_ids_per_batch",

View File

@@ -0,0 +1,291 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mla_cache_kernel.cuh"
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
constexpr int PackSize = 16 / sizeof(DataType_);
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
return {};
}
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cu_seqlens_q.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto bsz = meta_data.batch_size;
auto token_num = meta_data.token_nums;
auto block_size = meta_data.block_size;
auto nope_size = meta_data.head_dims_v;
auto all_size = meta_data.head_dims;
int pe_size = all_size - nope_size;
auto kv_num_heads = meta_data.kv_num_heads;
constexpr int PackSize = 16 / sizeof(DataType_);
const int blocksize = 128;
int grid_size = 1;
if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
} else {
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
max_blocks_per_seq,
kv_num_heads,
nope_size,
pe_size,
block_size,
elem_nums);
}
return {};
}
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
const auto& kv_nope_dims = kv_nope.dims();
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cu_seqlens_q.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
PD_BUILD_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
PD_BUILD_OP(decode_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
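
A small sketch of the launch-size arithmetic used by PrefillMLAWriteCache / DecodeMLAWriteCache above: each thread handles PackSize = 16 / sizeof(T) elements per 128-bit access, so the grid only needs to cover elem_nums / PackSize work items at 128 threads per block. GetNumBlocks is this repo's helper; the explicit block cap below is an assumption made only so the sketch stands alone.

// Hedged sketch of the vectorized launch-size computation; not the repo helper itself.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int sketch_grid_size(uint32_t elem_nums, uint32_t elem_bytes,
                     uint32_t block_threads = 128, uint32_t max_blocks = 65535) {
  const uint32_t pack_size = 16 / elem_bytes;       // elements per 128-bit vector access
  const uint32_t pack_num = elem_nums / pack_size;  // vectorized work items
  const uint32_t blocks = (pack_num + block_threads - 1) / block_threads;
  return static_cast<int>(std::min(blocks, max_blocks));
}

int main() {
  // e.g. 256 tokens, 1 kv head, head_dim 576 (512 nope + 64 pe), bf16 elements
  const uint32_t elem_nums = 256u * 1u * 576u;
  std::printf("grid = %d\n", sketch_grid_size(elem_nums, /*elem_bytes=*/2));
  return 0;
}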

View File

@@ -0,0 +1,242 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "mem_util.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1>
__global__ void decode_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
start_token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else { // pe part
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
start_token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
template <typename T, int VecSize = 1>
__global__ void speculate_decode_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [token_num, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [token_num, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
if (block_idx < 0) {
printf(
"Fatal error: invalid block_idx %d when write_seq_id is %d; "
"ori_bi %d, seq_len %d, token_id %d, cu_seqlens_q %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_id * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else { // pe part
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_id * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
template <typename T, int VecSize = 1>
__global__ void prefill_absorb_cache_kernel(
const T* __restrict__ kv_nope, // [token_num, kv_num_heads, nope_size] 512
const T* __restrict__ kv_pe, // [token_num, kv_num_heads, pe_size] 64
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size + pe_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
const int kv_num_heads,
const int nope_size,
const int pe_size,
const int block_size,
const uint32_t elem_cnt) {
using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
const uint32_t all_size = nope_size + pe_size;
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
const uint32_t ori_bi = ori_token_idx / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
const uint32_t block_offset = ori_seq_id % block_size;
if (bias < nope_hidden_size) { // nope part
const uint32_t inner_bias = bias;
const uint32_t hi = inner_bias / nope_size;
const uint32_t h_bias = inner_bias % nope_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + h_bias;
const uint32_t ori_idx =
token_idx * nope_hidden_size + inner_bias;
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
} else { // pe part
const uint32_t inner_bias = bias - nope_hidden_size;
const uint32_t hi = inner_bias / pe_size;
const uint32_t h_bias = inner_bias % pe_size;
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
hi * block_size * all_size +
block_offset * all_size + nope_size + h_bias;
const uint32_t ori_idx =
token_idx * pe_hidden_size + inner_bias;
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
}
}
}
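
The *_absorb_cache_kernel family above all address the cache the same way; a host-side sketch of that index arithmetic follows, assuming the layout [num_blocks, kv_num_heads, block_size, nope_size + pe_size] noted in the kernel comments, so a token's nope and pe parts land in the same (block, head, slot) row with the pe part offset by nope_size. Illustrative only; the kernels are the source of truth.

// Host-side sketch of the absorbed MLA cache addressing.
#include <cstdint>
#include <cstdio>

struct MlaCacheShape {
  int kv_num_heads;
  int block_size;
  int nope_size;  // e.g. 512
  int pe_size;    // e.g. 64
};

// Flat offset of element h_bias of the nope (is_pe == false) or pe
// (is_pe == true) portion for a token stored at (block_idx, head, slot).
uint64_t absorb_cache_offset(const MlaCacheShape& s, int block_idx, int head,
                             int slot, int h_bias, bool is_pe) {
  const uint64_t all_size = s.nope_size + s.pe_size;
  return static_cast<uint64_t>(block_idx) * s.kv_num_heads * s.block_size * all_size +
         static_cast<uint64_t>(head) * s.block_size * all_size +
         static_cast<uint64_t>(slot) * all_size +
         static_cast<uint64_t>(is_pe ? s.nope_size + h_bias : h_bias);
}

int main() {
  const MlaCacheShape s{1, 64, 512, 64};
  // First pe element of the token in block 3, head 0, slot 10:
  std::printf("%llu\n", static_cast<unsigned long long>(
                            absorb_cache_offset(s, 3, 0, 10, 0, /*is_pe=*/true)));
  return 0;
}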

View File

@@ -0,0 +1,38 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "helper.h"
#include "utils.cuh"
template <typename T>
void DecodeMLAAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
const paddle::Tensor &cache_k,
const paddle::Tensor &cache_v,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& shift_bias,
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
float softmax_scale,
float in_scale,
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);

View File

@@ -27,7 +27,7 @@ __global__ void append_clear_cache_int8_block(
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -44,7 +44,7 @@ __global__ void append_clear_cache_int8_block(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
if (seq_lens_encoder[bid] > 0) return;
@@ -101,7 +101,7 @@ __global__ void append_clear_cache_int4_block(
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -118,7 +118,7 @@ __global__ void append_clear_cache_int4_block(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
if (seq_lens_encoder[bid] > 0) return;
@@ -179,7 +179,7 @@ __global__ void append_speculate_cache_rope_kernel(
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -219,7 +219,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cum_offsets[ori_bi]);
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
@@ -312,7 +312,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -352,7 +352,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cum_offsets[ori_bi]);
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
@@ -459,7 +459,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -487,7 +487,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -691,7 +691,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -719,7 +719,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
@@ -1069,7 +1069,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1100,7 +1100,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1375,7 +1375,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1406,7 +1406,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
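
The speculate_* kernels in this file place each draft token at seq_lens_decoder[b] + (token_id - cu_seqlens_q[b]), so consecutive draft tokens of one request occupy consecutive cache slots after the already-decoded prefix. A toy illustration follows; the numbers are made up.

// Toy illustration of draft-token placement in the speculative decode path.
#include <cstdio>
#include <vector>

int main() {
  const int block_size = 64;                         // cache block size (illustrative)
  const std::vector<int> cu_seqlens_q = {0, 2, 5};   // 2 draft tokens for b=0, 3 for b=1
  const std::vector<int> seq_lens_decoder = {7, 40}; // tokens already in the cache
  for (int b = 0; b < 2; ++b) {
    for (int t = cu_seqlens_q[b]; t < cu_seqlens_q[b + 1]; ++t) {
      const int write_seq_id = seq_lens_decoder[b] + (t - cu_seqlens_q[b]);
      // The kernel would then look up block_tables[b][write_seq_id / block_size]
      // and write at offset write_seq_id % block_size within that block.
      std::printf("batch %d, packed token %d -> cache slot %d (block offset %d)\n",
                  b, t, write_seq_id, write_seq_id % block_size);
    }
  }
  return 0;
}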

View File

@@ -23,7 +23,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -60,7 +60,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
@@ -83,7 +83,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
@@ -107,7 +107,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -137,7 +137,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
seq_lens,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -152,7 +152,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -176,7 +176,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -202,7 +202,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -234,7 +234,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
seq_lens,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -249,7 +249,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -275,7 +275,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -302,7 +302,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -350,7 +350,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -377,7 +377,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -410,7 +410,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -443,7 +443,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -489,7 +489,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -515,7 +515,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -540,7 +540,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -567,7 +567,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
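The hunks above, like most of the hunks that follow, replace the cum_offsets argument with cu_seqlens_q. For orientation, the sketch below shows the conventional varlen-attention meaning of such an array: an exclusive prefix sum of per-sequence query lengths, so that cu_seqlens_q[i + 1] - cu_seqlens_q[i] is the query length of sequence i. The helper name and the use of std::vector are illustrative assumptions; the exact way FastDeploy fills this tensor is not shown in this diff.

// Illustrative sketch only, not code from this change.
#include <vector>

std::vector<int> build_cu_seqlens_q(const std::vector<int>& seq_lens_q) {
  std::vector<int> cu(seq_lens_q.size() + 1, 0);
  for (size_t i = 0; i < seq_lens_q.size(); ++i) {
    cu[i + 1] = cu[i] + seq_lens_q[i];  // exclusive prefix sum
  }
  return cu;  // e.g. {3, 1, 5} -> {0, 3, 4, 9}
}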

View File

@@ -24,7 +24,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -39,7 +39,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -23,7 +23,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
int kv_num_heads;
int token_nums;
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
};
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} \
}
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 128: { \
constexpr size_t HEAD_DIM = 128; \
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
case 512: { \
constexpr size_t HEAD_DIM = 512; \
__VA_ARGS__ \
break; \
} \
case 576: { \
constexpr size_t HEAD_DIM = 576; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
}
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
if (num_stage == 2) { \
constexpr size_t NUM_STAGE = 2; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the num_stage: ", num_stage); \
}
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
constexpr size_t cache_bytes = 4; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the cache_type: ", cache_type); \
}
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
if (deal_each_time == 32) { \
if (deal_each_time == 32) { \
constexpr size_t DEAL_EACH_TIME = 32; \
__VA_ARGS__ \
} else if (deal_each_time == 64) { \
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
}
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
if (block_shape_q <= 16) { \
constexpr size_t BLOCK_SHAPE_Q = 16; \
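The DISPATCH_GQA_HEAD_DIM and DISPATCH_MLA_HEAD_DIM macros added in this file map a runtime head_dim onto a compile-time constant so it can feed a template parameter, throwing for unsupported sizes. A minimal usage sketch, assuming a hypothetical templated launcher (LaunchMlaKernel is not a function from this diff; PD_THROW comes from the Paddle custom-op headers this file already relies on):

template <size_t HeadDim>
void LaunchMlaKernel();  // hypothetical callee, arguments elided

void DispatchByHeadDim(int head_dim) {
  // Expands to a switch over the supported head dims (128/192/512/576 here).
  DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, { LaunchMlaKernel<HEAD_DIM>(); })
}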

View File

@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
@@ -94,7 +94,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids,
@@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
paddle::Tensor FusedExpertMoeFunc(
const paddle::Tensor &input, const paddle::Tensor &gate_weight,
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
const paddle::optional<paddle::Tensor> &ffn1_bias,
const paddle::optional<paddle::Tensor> &ffn1_scale,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &ffn2_scale,
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_scale,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_scale,
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
@@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
std::vector<paddle::Tensor>
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
const paddle::Tensor &topk_weights,
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
const paddle::optional<paddle::Tensor> &up_gate_proj_in_scale,
const std::vector<int> &token_nums_per_expert,
const int token_nums_this_rank,
const std::string &moe_quant_type);
@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor &input, const paddle::Tensor &scale,
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
const paddle::Tensor &token_nums_per_expert,
const paddle::Tensor &token_nums_per_expert_padded);
const paddle::Tensor &token_nums_per_expert_padded,
const bool use_in_ep, const int token_nums_this_rank_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size);
@@ -172,7 +173,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
@@ -181,35 +182,35 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
@@ -316,6 +317,95 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
int64_t num_experts);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch);
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool speculate_decoder);
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len);
void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox);
std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& query_bias,
const paddle::optional<paddle::Tensor>& query_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const int nope_size,
const int max_input_length,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder);
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
@@ -370,6 +460,270 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
paddle::Tensor const &input,
paddle::Tensor &scales, float scale_ub);
std::vector<paddle::Tensor> NoauxTc(
paddle::Tensor& scores,
paddle::Tensor& scores_with_bias,
int n_group,
int topk_group,
int topk,
float routed_scaling_factor);
#ifdef ENABLE_FP8
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale, // only support per-tensor quantization
std::string output_dtype,
std::string activation_type);
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled);
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale);
#endif
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa);
int64_t meta_size();
void register_buffer(int64_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(int64_t _fa);
void register_graph_buffers(int64_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, paddle::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(paddle::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
// speculative decoding Kernel
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder);
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
const int max_seq_len);
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &output_padding_offset,
const paddle::Tensor &output_cum_offsets,
const int max_seq_len);
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids);
void SpeculateVerify(
const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums);
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx);
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank);
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_stop_flags);
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_ids,
const paddle::Tensor& base_model_draft_tokens,
const int max_seq_len,
const int substep);
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num);
void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags,
const paddle::Tensor &stop_flags,
const paddle::Tensor &batch_drop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const int block_size,
const int max_draft_tokens);
void SpeculateStepPaddle(
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -461,7 +815,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* ep_moe_dispatch
*/
m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"),
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"),
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"),
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
py::arg("moe_quant_type"), "ep moe export dispatch function");
@@ -469,7 +823,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
py::arg("top_k_indices"), py::arg("ffn2_bias"),
py::arg("top_k_indices"), py::arg("down_proj_bias"),
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
"ep moe export combine function");
@@ -511,7 +865,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
*/
m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"),
py::arg("top_k_weight"), py::arg("permute_indices_per_token"),
py::arg("top_k_indices"), py::arg("ffn2_bias"),
py::arg("top_k_indices"), py::arg("down_proj_bias"),
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
"moe export reduce function");
@@ -539,9 +893,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attn/get_block_shape_and_split_kv_block.cu
* get_block_shape_and_split_kv_block
*/
// m.def("f_get_block_shape_and_split_kv_block",
// &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
// function");
m.def("get_block_shape_and_split_kv_block",
&GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function");
/**
* get_padding_offset.cu
@@ -602,32 +955,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
py::arg("a"),
py::arg("c_or_none"),
py::arg("b_q_weight"),
py::arg("b_scales"),
py::arg("global_scale_or_none"),
py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"),
py::arg("perm_or_none"),
py::arg("workspace"),
py::arg("sorted_token_ids"),
py::arg("expert_ids"),
py::arg("num_tokens_post_padded"),
py::arg("topk_weights"),
py::arg("moe_block_size"),
py::arg("top_k"),
py::arg("mul_topk_weights"),
py::arg("is_ep"),
py::arg("b_q_type_str"),
py::arg("size_m"),
py::arg("size_n"),
py::arg("size_k"),
py::arg("is_k_full"),
py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"),
py::arg("is_zp_float"));
py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");
/**
* cutlass_scaled_mm.cu
@@ -653,4 +990,80 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
"dynamic_per_token_scaled_fp8_quant function",
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
#ifdef ENABLE_FP8
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
#endif
m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");
m.def("all_reduce", &all_reduce, "all reduce function");
m.def("dispose", &dispose, "del function for python");
m.def("meta_size", &meta_size, "meta_size function for Signal struct");
m.def("register_buffer", &register_buffer, "register ipc buffer");
m.def("register_graph_buffers", &register_graph_buffers, "register_graph_buffers");
m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle");
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
// speculative decoding Kernel
m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function");
m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function");
m.def("speculate_get_output_padding_offset",&SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function");
m.def("speculate_get_token_penalty_multi_scores",&SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function");
m.def("speculate_set_stop_value_multi_seqs",&SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function");
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function");
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
m.def("draft_model_update",&DraftModelUpdate, "draft_model_update function");
m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function");
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
}

View File

@@ -0,0 +1,165 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "all_reduce.cuh"
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
paddle::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
rank_data.numel(), rank, world_size,
full_nvlink);
}
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream();
// NOTE: the staging copy below is sized assuming a 2-byte element type (fp16/bf16).
auto input_size = inp.numel() * 2;
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
cudaMemcpyDeviceToDevice, stream);
} else {
reg_buffer = inp.data();
}
switch (out.dtype()) {
case phi::DataType::FLOAT32: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data()),
out.numel());
break;
}
case phi::DataType::FLOAT16: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data()), out.numel());
break;
}
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
case phi::DataType::BFLOAT16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void dispose(fptr_t _fa) {
delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
}
int64_t meta_size() { return sizeof(paddle::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
}
fa->register_buffer(ipc_ptrs);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
bytes.emplace_back(handles[i].begin(), handles[i].end());
}
fa->register_graph_buffers(bytes, offsets);
}
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = phi::backends::gpu::GetCurrentDeviceId();
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
CUDACHECK(cudaMalloc((void**)&buffer, size));
CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
CUDACHECK(cudaStreamSynchronize(stream));
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto handle =
paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
CUDACHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
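Taken together, the bindings above form a small lifecycle. A minimal sketch of the intended call order, assuming the IPC signal and data pointers have already been exchanged across ranks out of band (for example via allocate_shared_buffer_and_handle plus open_mem_handle on each peer); the wrapper function itself is illustrative and not part of this file:

// Illustrative only; relies on the functions defined above.
void run_custom_allreduce_once(const std::vector<fptr_t>& signal_ptrs,
                               const std::vector<fptr_t>& data_ptrs,
                               paddle::Tensor& rank_data, int64_t rank,
                               paddle::Tensor& inp, paddle::Tensor& out) {
  fptr_t fa = init_custom_all_reduce(signal_ptrs, rank_data, rank,
                                     /*full_nvlink=*/true);
  register_buffer(fa, data_ptrs);  // data_ptrs[rank] should be inp.data()
  // A null reg_buffer means inp.data() is assumed to be IPC-registered already.
  all_reduce(fa, inp, out, /*_reg_buffer=*/0, /*reg_buffer_sz_bytes=*/0);
  dispose(fa);
}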

View File

@@ -0,0 +1,526 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <array>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace paddle {
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for the two syncs. The reason is that
// a peer GPU block may arrive at the second sync point while the current GPU
// block hasn't passed the first sync point. The peer GPU would then write
// counter+1 while the current GPU is still busy waiting for counter. We use
// alternating counter arrays to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
struct __align__(16) RankSignals {
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
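// Illustrative sanity checks of the packing described above: for 2-byte
// element types the packed load/store type spans 16 bytes (eight half lanes
// per ld.128/st.128), with one float accumulator lane per input lane.
static_assert(sizeof(packed_t<half>::P) == 16, "packed type covers 16 bytes");
static_assert(packed_t<half>::P::size == 8, "eight half lanes per 128-bit access");
static_assert(packed_t<half>::A::size == 8, "one float accumulator lane per half lane");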
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Paddle, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
: "=r"(flag)
: "l"(flag_addr));
#endif
return flag;
}
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(
!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr =
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr =
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val);
}
}
if constexpr (is_start || need_fence) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the addresses, so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
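// Worked example of the stage-1 partition above (illustrative): with 1001
// packed elements across 4 ranks, part = 1001 / 4 = 250, ranks 0-2 reduce 250
// elements each, and rank 3 reduces the remaining 251 elements, which is also
// the largest_part every rank loops over in the allgather stage.
static_assert(1001 / 4 + 1001 % 4 == 251, "last rank absorbs the remainder");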
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
RankSignals sg_;
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each address used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i];
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
      // note: we must use the base address of each allocation, or we get the
      // wrong address
if (cuPointerGetAttribute(&base_ptr,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " +
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CUDACHECK(
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[ptrs[rank_]] = d_data;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce while rank 2
  // gets a different one. IPC handles have an internal reference-counting
  // mechanism, so the overhead should be small.
void register_graph_buffers(
const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
sizeof(RankData) * num_buffers,
cudaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* Performs allreduce, assuming input has already been registered.
*
 * Block and grid default configs are the result of a careful grid search. Using
 * 36 blocks gives the best, or close to the best, runtime on the devices tried:
 * A100, A10, A30, T4, V100. NCCL kernels also use only a small number of SMs;
 * the underlying reason is not entirely clear, but too many SMs likely cause
 * contention on the NVLink bus. (A usage sketch follows this class.)
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
RankData* ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
}
};
} // namespace paddle
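For orientation, here is a minimal, hedged usage sketch of the class above in the eager (non-CUDA-graph) path. The helper name, the half element type, and the assumption that signals and peer_ptrs already hold IPC-opened pointers gathered from all ranks are illustrative, not part of the source.

// Sketch only: allocation, IPC handle exchange and error handling are omitted.
void run_allreduce_once(paddle::Signal** signals, void** peer_ptrs,
                        void* rank_data, size_t rank_data_sz,
                        half* input, half* output, int numel,
                        int rank, int world_size, cudaStream_t stream) {
  paddle::CustomAllreduce ar(signals, rank_data, rank_data_sz, rank, world_size);
  ar.register_buffer(peer_ptrs);          // peer_ptrs[rank] must be `input` on this rank
  ar.allreduce<half>(stream, input, output, numel);
}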

View File

@@ -76,6 +76,34 @@ enum class SplitKStyle
// SPLIT_K_PARALLEL // Not supported yet
};
// New enum for SM100 (Blackwell) Tile Configs
// Placeholder values - actual optimal values need research
enum class CutlassTileConfigSM100
{
    // Signals that we should run heuristics to choose a config
Undefined,
    // Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// Actual SM100 tile configs based on user input (K-tile is 128B)
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
CtaShape256x64x128B,
CtaShape256x128x128B,
CtaShape256x256x128B
    // Note: the candidate list in get_candidate_tiles_sm100 also uses
    // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped-GEMM cases.
    // Those shapes are already covered by the list above, so no extra enum values
    // are needed; the enum stays concise, listing each unique shape once.
};
enum class CutlassTileConfigSM90
{
    // Signals that we should run heuristics to choose a config
@@ -132,9 +160,11 @@ struct CutlassGemmConfig
WEIGHT_ONLY = 1u << 0,
SIMT_ONLY = 1u << 1,
INT8_ONLY = 1u << 2,
HOPPER = 1u << 3,
HOPPER = 1u << 3, // SM90
GROUPED_GEMM = 1u << 4,
FP8_ONLY = 1u << 5,
BLACKWELL = 1u << 6, // SM100
FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
};
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
@@ -149,7 +179,17 @@ struct CutlassGemmConfig
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
bool is_sm90 = false;
CutlassGemmConfig() {}
// config options for sm100 (Blackwell)
// Assuming SM100 might use similar schedule/cluster types as SM90 for now.
// These might need to become SM100-specific if Blackwell introduces new concepts.
CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
// MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
// EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
// ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
bool is_sm100 = false;
CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
: tile_config(tile_config)
@@ -157,37 +197,64 @@ struct CutlassGemmConfig
, split_k_factor(split_k_factor)
, stages(stages)
, is_sm90(false)
, is_sm100(false)
{
}
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
: tile_config_sm90(tile_config_sm90)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
// Constructor for SM90
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
: tile_config_sm90(tile_config_sm90_in)
, mainloop_schedule(mainloop_schedule_in)
, epilogue_schedule(epilogue_schedule_in)
, cluster_shape(cluster_shape_in)
, is_sm90(true)
, is_sm100(false)
{
}
// Constructor for SM100 (Blackwell)
// Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
// These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
: tile_config_sm100(tile_config_sm100_in)
, mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
, epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
, cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
, is_sm90(false) // Explicitly false
, is_sm100(true)
{
}
std::string toString() const
{
std::stringstream tactic;
tactic << "Cutlass GEMM Tactic";
if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
{
assert(is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=TMA"
<< "\n\ttile shape ID: " << (int) tile_config_sm90
assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable
<< "\n\ttile shape ID: " << (int) tile_config_sm100
<< "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule
<< "\n\tmainloop sched: " << (int) mainloop_schedule
<< "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
{
assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
tactic << "\n\tstyle=TMA_SM90"
<< "\n\ttile shape ID: " << (int) tile_config_sm90
<< "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule
<< "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
assert(!is_sm90 && "Invalid cutlass GEMM config");
assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
tactic << "\n\tstyle=compatible"
<< "\n\ttile shape ID: " << (int) tile_config
<< "\n\ttile shape ID: " << (int) tile_config
<< "\n\tstages: " << (int) stages
<< "\n\tsplit_k_style: " << (int) split_k_style
<< "\n\tsplit k: " << (int) split_k_factor;
@@ -204,9 +271,24 @@ struct CutlassGemmConfig
std::istringstream stream(str);
std::string line;
is_sm90 = false; // Reset flags
is_sm100 = false;
while (std::getline(stream, line)) {
if (line.find("style=TMA") != std::string::npos) {
if (line.find("style=TMA_SM100") != std::string::npos) {
is_sm100 = true;
is_sm90 = false;
std::getline(stream, line);
tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
} else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
is_sm90 = true;
is_sm100 = false;
std::getline(stream, line);
tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
@@ -217,6 +299,7 @@ struct CutlassGemmConfig
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
} else if (line.find("style=compatible") != std::string::npos) {
is_sm90 = false;
is_sm100 = false;
std::getline(stream, line);
tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
@@ -233,7 +316,14 @@ struct CutlassGemmConfig
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
{
// clang-format off
if (config.is_sm90)
if (config.is_sm100)
{
out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
<< ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
}
else if (config.is_sm90)
{
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)

View File

@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
#endif
}
// SM100 (Blackwell) candidate tile configurations
std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
#ifdef FAST_BUILD
return {CutlassTileConfigSM100::CtaShape128x128x128B};
#else
/* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
{
return {
/* 1 SM (M=128) */
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
/* 2 SM (M=256) */
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B,
/* slim tiles for very tall matrices */
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape256x64x128B};
}
/* Fp8 / Fp16 grouped-GEMM */
return {
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
}
/* Non-grouped path (plain GEMM or weight-only) */
return {
/* 1 SM tiles */
CutlassTileConfigSM100::CtaShape64x64x128B,
CutlassTileConfigSM100::CtaShape64x128x128B,
CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
/* 2 SM tiles */
CutlassTileConfigSM100::CtaShape256x64x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
#endif
}
// M-multicast support for SM100.
bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM100> m_tiles{
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x64x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
return m_tiles.count(tile) == 1;
#endif
}
// N-multicast support for SM100.
bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM100> n_tiles{
CutlassTileConfigSM100::CtaShape64x128x128B,
CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x128x128B};
return n_tiles.count(tile) == 1;
#endif
}
std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
}
return candidate_configs;
}
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
{
std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs;
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config_sm100 : tiles)
{
// SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
// Cluster shapes are also handled similarly.
CutlassGemmConfig config(
tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);
bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
if (has_m_mcast)
{
CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(mcast_m_config);
}
if (has_n_mcast)
{
CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(mcast_n_config);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(mcast_mn_config);
}
}
return candidate_configs;
}
// Fallback to older architecture configurations
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
// It's fine here as it's within an else if / else block.
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);

View File

@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
ffn1_n=7168
ffn1_k=8192
up_gate_proj_n=7168
up_gate_proj_k=8192
ffn2_n=8192
ffn2_k=3584
rm -rf ffn1_7168_8192.log
rm -rf ffn2_8192_3584.log
down_proj_n=8192
down_proj_k=3584
rm -rf up_gate_proj_7168_8192.log
rm -rf down_proj_8192_3584.log
num_experts=8
for tokens_per_expert in 12
do
wait
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
done
wait
echo "#### finish ####"

custom_ops/gpu_ops/env.h Normal file
View File

@@ -0,0 +1,64 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cstdint>
#include <cstdlib>
#include <string>
inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}
inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
}
inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
return max_partition_size;
}
inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}
inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}
inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}
inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
  return mla_use_tensorcore != 0;
}
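Each helper above reads a FLAGS_* environment variable once (the static locals cache the result) and falls back to a built-in default. A hedged standalone sketch of the same pattern, using a made-up flag name:

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

// Read-once-with-default, mirroring the helpers above (FLAGS_example_block_size is hypothetical).
inline uint32_t get_example_block_size() {
  static const char* env = std::getenv("FLAGS_example_block_size");
  static const uint32_t value = env == nullptr ? 128 : std::stoul(std::string(env));
  return value;
}

int main() {
  // `FLAGS_example_block_size=256 ./a.out` prints 256; without the variable it prints 128.
  std::printf("%u\n", get_example_block_size());
  return 0;
}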

View File

@@ -19,7 +19,7 @@
#include "fp8_fp8_half_cuda_core_gemm.h"
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
@@ -142,7 +142,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
{
if(output_dtype == "bfloat16") {
cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
} else {
cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
}
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
fuse_gemm_config};
fp8_fp8_gemm_scale_bias_act(params);
}
return {out};
return out;
}
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale, // only support per-tensor quantization
std::string output_dtype,
std::string activation_type) {
return {cutlass_fp8_fp8_half_gemm_func(
x, y, bias, trans_x, trans_y, scale,
output_dtype, activation_type)};
}
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(

View File

@@ -0,0 +1,198 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include "helper.h"
__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
int lane_id = threadIdx.x % 32;
#pragma unroll
for (int step = 0; step < 5; ++step) {
const int lane_mask = 1 << step;
const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
__nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
x = sign * x + x_val_other;
}
}
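hadamard32_warp applies an unnormalized 32-point Walsh-Hadamard transform across the 32 lanes of a warp, using XOR shuffles as the butterfly exchange. A scalar host-side reference of the same butterfly (illustrative only, not part of this file):

#include <cstdio>

int main() {
  float x[32] = {0.f}, y[32];
  x[0] = 1.f;                                       // unit impulse on lane 0
  for (int step = 0; step < 5; ++step) {
    const int mask = 1 << step;
    for (int lane = 0; lane < 32; ++lane) {
      const float sign = (lane & mask) ? -1.f : 1.f;
      y[lane] = sign * x[lane] + x[lane ^ mask];    // partner value stands in for __shfl_xor_sync
    }
    for (int lane = 0; lane < 32; ++lane) x[lane] = y[lane];
  }
  for (int i = 0; i < 4; ++i) std::printf("%g ", x[i]);  // 1 1 1 1: first Hadamard column
  std::printf("\n");
  return 0;
}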
__global__ void MoeFusedHadamardQuantFp8Kernel(
const __nv_bfloat16* __restrict__ input,
const float* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
__nv_fp8_e4m3* out,
const int top_k,
const int intermediate_size,
const int64_t numel
) {
int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (out_idx >= numel) return;
int64_t token_idx = out_idx / (top_k * intermediate_size);
int64_t topk_idx = (out_idx / intermediate_size) % top_k;
int64_t inter_idx = out_idx % intermediate_size;
int64_t input_idx = token_idx * intermediate_size + inter_idx;
if (input_idx >= numel / top_k) return;
int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
float scale_value = scale[expert_id];
__nv_bfloat16 x = input[input_idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale_value;
out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
__global__ void MoeFusedHadamardQuantFp8TiledKernel(
const __nv_bfloat16* __restrict__ input,
const float* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
__nv_fp8_e4m3* out,
const int top_k,
const int intermediate_size,
const int64_t numel
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= numel) return;
int64_t token_idx = idx / intermediate_size;
int64_t expert_id = topk_ids[token_idx];
float scale_value = scale[expert_id];
__nv_bfloat16 x = input[idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale_value;
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled) {
int64_t numel = input.numel();
if (!tiled) numel *= top_k;
paddle::Tensor out = GetEmptyTensor(
{numel / intermediate_size, intermediate_size},
paddle::DataType::FLOAT8_E4M3FN,
input.place());
constexpr int64_t thread_per_block = 256;
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
auto stream = input.stream();
if (tiled) {
MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
scale.data<float>(),
topk_ids.data<int64_t>(),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
top_k,
intermediate_size,
numel
);
} else {
MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
scale.data<float>(),
topk_ids.data<int64_t>(),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
top_k,
intermediate_size,
numel
);
}
return {out};
}
PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
.Inputs({"input", "scale", "topk_ids"})
.Outputs({"output"})
.Attrs({"top_k: int",
"intermediate_size: int",
"tiled: bool"})
.SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled) {
return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
}
__global__ void FusedHadamardQuantFp8Kernel(
const __nv_bfloat16* __restrict__ input,
__nv_fp8_e4m3* out,
const float scale,
const int64_t numel) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= numel) return;
__nv_bfloat16 x = input[idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale;
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
std::vector<paddle::Tensor> FusedHadamardQuantFp8(
const paddle::Tensor &input,
const float scale) {
int64_t numel = input.numel();
paddle::Tensor out = GetEmptyTensor(
input.dims(),
paddle::DataType::FLOAT8_E4M3FN,
input.place());
constexpr int64_t thread_per_block = 256;
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
auto stream = input.stream();
FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
scale,
numel
);
return {out};
}
PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
.Inputs({"input"})
.Outputs({"output"})
.Attrs({"scale: float"})
.SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale) {
return FusedHadamardQuantFp8(input, scale)[0];
}

View File

@@ -0,0 +1,146 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_kernel(
T* __restrict__ arr,
const T* __restrict__ cos_ptr,
const T* __restrict__ sin_ptr,
int rot_offset,
int embed_dim) {
int x_index, y_index;
T cos, sin;
if (IS_NEOX) {
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = cos_ptr[x_index];
sin = sin_ptr[x_index];
} else {
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = cos_ptr[x_index / 2];
sin = sin_ptr[x_index / 2];
}
const T x = arr[x_index];
const T y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template <typename T, bool IS_NEOX>
__global__ void apply_rotary_embedding_kernel(
T* __restrict__ query, // [num_tokens, num_heads, head_size]
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const int* __restrict__ position_ids, // [num_tokens]
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int pos = position_ids[token_idx];
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const T* cos_ptr = cache_ptr;
const T* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}
void FusedRotaryPositionEncoding(
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
// [num_tokens, num_heads * head_size]
paddle::Tensor& key,
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
// head_size]
const paddle::Tensor& position_ids, // [num_tokens]
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
int head_size,
bool is_neox) {
int64_t num_tokens = query.dims()[0];
int num_heads = query.numel() / num_tokens / head_size;
int num_kv_heads = key.numel() / num_tokens / head_size;
int rot_dim = cos_sin_cache.dims()[1];
int64_t query_stride = num_heads * head_size;
int64_t key_stride = num_kv_heads * head_size;
if (num_tokens > 65535) {
PD_THROW(
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
}
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
query.dtype(), "apply_rotary_embedding_kernel", [&] {
if (is_neox) {
apply_rotary_embedding_kernel<data_t, true>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
apply_rotary_embedding_kernel<data_t, false>
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
key.data<data_t>(),
position_ids.data<int>(),
cos_sin_cache.data<data_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
PD_BUILD_OP(fused_rotary_position_encoding)
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
.Outputs({"query_out", "key_out"})
.Attrs({"head_size: int", "is_neox: bool"})
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
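The kernel above rotates pairs of elements whose indices depend on the layout: NEOX style pairs element rot_offset with embed_dim + rot_offset, while the interleaved (non-NEOX) style pairs 2*rot_offset with 2*rot_offset + 1, reading cos/sin from the cached angle of that rot_offset. A small standalone sketch of the two pairings (the embed_dim value is made up):

#include <cstdio>

int main() {
  const int embed_dim = 4;  // rot_dim / 2, hypothetical
  for (int r = 0; r < embed_dim; ++r)
    std::printf("NEOX pair: (%d, %d)\n", r, embed_dim + r);          // (0,4) (1,5) (2,6) (3,7)
  for (int r = 0; r < embed_dim; ++r)
    std::printf("interleaved pair: (%d, %d)\n", 2 * r, 2 * r + 1);   // (0,1) (2,3) (4,5) (6,7)
  return 0;
}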

View File

@@ -24,16 +24,18 @@
#endif
#define MAX_BSZ 512
#define K 10
#define K 20
struct msgdata {
long mtype;
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
float mtext_f[MAX_BSZ * (K + 1)]; // score
int mtext_ranks[MAX_BSZ]; // ranks
};
void GetOutputTopK(const paddle::Tensor& x,
const paddle::Tensor& scores,
const paddle::Tensor& ranks,
int k,
int64_t rank_id,
bool wait_flag) {
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
float* scores_data = const_cast<float*>(scores.data<float>());
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
0,
IPC_NOWAIT);
} else {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
0,
0);
}
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
scores_data[offset] = msg_rcv.mtext_f[offset];
}
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
}
return;
}
PD_BUILD_STATIC_OP(get_output_topk)
.Inputs({"x", "scores"})
.Inputs({"x", "scores", "ranks"})
.Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out", "scores_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
.Outputs({"x_out", "scores_out", "ranks_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
.SetKernelFn(PD_KERNEL(GetOutputTopK));
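The enlarged msgrcv length in this diff corresponds to the payload of msgdata after mtext_ranks is added: tokens plus metadata, scores, and one rank per batch entry. A hedged standalone check of that arithmetic (assumes 4-byte int and float, as the hard-coded * 4 factors do):

#include <cstddef>

constexpr int MAX_BSZ = 512;
constexpr int K = 20;

struct msgdata_payload {               // msgdata minus the mtype header, which msgrcv excludes
  int   mtext[MAX_BSZ * (K + 1) + 2];  // stop_flag, bsz, tokens
  float mtext_f[MAX_BSZ * (K + 1)];    // scores
  int   mtext_ranks[MAX_BSZ];          // ranks
};

static_assert(sizeof(msgdata_payload) ==
                  (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
              "msgrcv length must cover tokens, scores and ranks");

int main() { return 0; }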

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -59,7 +60,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = input_ids.stream();
#endif
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
@@ -75,7 +81,11 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
#ifdef PADDLE_WITH_COREX
int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#else
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),

View File

@@ -0,0 +1,86 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
    const int* seq_lens_encoder,   // [bsz] encoder length of each batch
    const int* seq_lens_decoder,   // [bsz] decoder length of each batch
    const int* seq_lens_this_time,
    int* position_ids,             // flattened 1-D position_ids output
    int* mask_encoder_batch,
    const int bsz) {               // batch size
  // current thread index (one thread per batch)
int tid = threadIdx.x;
if (tid >= bsz) return;
  // compute this batch's write offset from the preceding batches
int offset = 0;
for (int i = 0; i < tid; i++) {
offset += seq_lens_encoder[i];
if (seq_lens_decoder[i] > 0) {
offset += seq_lens_this_time[i];
}
}
  // encoder and decoder lengths of the current batch
int encoder_len = seq_lens_encoder[tid];
int decoder_len = seq_lens_decoder[tid];
int seq_len_this_time = seq_lens_this_time[tid];
  // write position_ids for the encoder (prefill) tokens
for (int i = 0; i < encoder_len; i++) {
position_ids[offset + i] = i;
mask_encoder_batch[offset + i] = 1;
}
offset += encoder_len;
  // write position_ids for the decoder tokens
if (decoder_len > 0) {
for (int i = 0; i < seq_len_this_time; i++) {
      position_ids[offset + i] = decoder_len + i;  // continue from the decoder length
mask_encoder_batch[offset + i] = 0;
}
}
}
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids,
const paddle::Tensor& mask_encoder_batch) {
const int bsz = seq_lens_this_time.shape()[0];
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
const_cast<int*>(position_ids.data<int>()),
const_cast<int*>(mask_encoder_batch.data<int>()),
bsz);
}
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
.Inputs({"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"position_ids",
"mask_encoder_batch"})
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"},
{"mask_encoder_batch", "mask_encoder_batch_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
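As a worked example of the indexing above, here is a host-side reference loop (illustrative only; the batch sizes and lengths are made up). Prefill tokens of each batch get positions 0..len-1 with mask 1, while decode tokens continue from the batch's decoder length with mask 0:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> enc{3, 0}, dec{0, 5}, cur{3, 1};   // seq_lens_encoder / decoder / this_time
  std::vector<int> pos, mask;
  for (size_t b = 0; b < enc.size(); ++b) {
    for (int i = 0; i < enc[b]; ++i) { pos.push_back(i); mask.push_back(1); }
    if (dec[b] > 0)
      for (int i = 0; i < cur[b]; ++i) { pos.push_back(dec[b] + i); mask.push_back(0); }
  }
  for (size_t i = 0; i < pos.size(); ++i)
    std::printf("(%d, %d) ", pos[i], mask[i]);        // (0,1) (1,1) (2,1) (5,0)
  std::printf("\n");
  return 0;
}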

View File

@@ -14,7 +14,9 @@
#pragma once
#ifndef PADDLE_WITH_COREX
#include "glog/logging.h"
#endif
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
@@ -35,20 +37,35 @@ namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif
#ifndef PADDLE_WITH_COREX
#include "nlohmann/json.hpp"
#endif
#include <fstream>
#include <iostream>
#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/custom/custom_context.h"
#else
#include "paddle/phi/core/cuda_stream.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#ifdef PADDLE_WITH_COREX
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#ifndef PADDLE_WITH_COREX
using json = nlohmann::json;
#endif
#define CUDA_CHECK(call) \
do { \
@@ -197,11 +214,19 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
*addr_vec = vec;
}
#ifdef PADDLE_WITH_HIP
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
int8_t *addr) {
printf("Error: Store hip_bfloat16 to int8_t is not supported!");
}
#else
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
int8_t *addr) {
printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
}
#endif
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
@@ -235,6 +260,7 @@ inline int GetBlockSize(int vocab_size) {
}
}
#ifndef PADDLE_WITH_COREX
inline json readJsonFromFile(const std::string &filePath) {
std::ifstream file(filePath);
if (!file.is_open()) {
@@ -245,6 +271,7 @@ inline json readJsonFromFile(const std::string &filePath) {
file >> j;
return j;
}
#endif
#define cudaCheckError() \
{ \
@@ -416,6 +443,7 @@ inline std::string base64_decode(const std::string &encoded_string) {
return ret;
}
#ifndef PADDLE_WITH_COREX
template <typename T>
inline T get_relative_best(nlohmann::json *json_data,
const std::string &target_key,
@@ -428,6 +456,7 @@ inline T get_relative_best(nlohmann::json *json_data,
return default_value;
}
}
#endif
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
int length) {
@@ -457,7 +486,12 @@ template <typename T>
static void PrintMatrix3(const T *mat_d, int num, std::string name) {
std::vector<T> tmp(num);
#ifdef PADDLE_WITH_HIP
hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
#else
cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
#endif
std::ofstream outfile;
outfile.open(name + ".txt", std::ios::out);
@@ -474,6 +508,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
outfile.close();
}
#ifndef PADDLE_WITH_HIP
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
int mode = 0) {
uint32_t flag;
@@ -513,3 +548,11 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
return max_shared_mem_per_block_opt_in;
}
#endif
inline int GetSMVersion() {
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
}

View File

@@ -0,0 +1,255 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#include <cute/tensor.hpp>
#include <cutlass/detail/helper_macros.hpp>
#include "utils.cuh"
namespace mla_attn {
using namespace cute;
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
};
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
};
template <int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator& op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template <>
struct Allreduce<2> {
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator& op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& summary, Operator& op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
Tensor<Engine1, Layout1>& src, Operator& op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++) {
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& summary, Operator& op) {
thread_reduce_<init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& max) {
MaxOp<float> max_op;
reduce_<init>(tensor, max, max_op);
}
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
Tensor<Engine1, Layout1>& sum) {
SumOp<float> sum_op;
thread_reduce_<init>(tensor, sum, sum_op);
if constexpr (warp_reduce) {
quad_allreduce_(sum, sum, sum_op);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
Tensor<Engine1, Layout1> const& max) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
auto row_max = max(mi);
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
}
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
Tensor<Engine1, Layout1> const& max,
const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
auto row_max = max(mi);
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// row_max * scale is a constant for each row, so we can use fma here
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
}
}
}
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
struct OnlineSoftmax {
constexpr static float fill_value = -5e4;
using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
TensorT row_max, row_sum, scores_scale;
float sm_scale_log2;
CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
clear(scores_scale);
};
__forceinline__ __device__ TensorT get_lse() const { return row_sum; }
template <bool init, typename Tensor0>
__forceinline__ __device__ TensorT update(Tensor0& acc_s) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
if constexpr (init) {
reduce_max</*init=*/true>(scores, row_max);
if constexpr (WITH_SCALE) {
scale_apply_exp2(scores, row_max, sm_scale_log2);
} else {
apply_exp2(scores, row_max);
}
reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
} else {
// update row_max
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*init=*/false>(scores, row_max);
// update scores_scale and scale row_sum
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
if constexpr (WITH_SCALE) {
scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
} else {
scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
}
row_sum(mi) *= scores_scale(mi);
}
// perform exp2 on scores
if constexpr (WITH_SCALE) {
scale_apply_exp2(scores, row_max, sm_scale_log2);
} else {
apply_exp2(scores, row_max);
}
// update row_sum
reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
    }
    return scores_scale;
  };
template <typename Tensor0>
__forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.f / sum;
scores_scale(mi) = inv_sum;
row_max(mi) *= sm_scale_log2;
}
return scores_scale;
};
template <typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
template <typename Tensor1, typename Tensor2>
__forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
}
}
};
};
} // namespace mla_attn
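The OnlineSoftmax bookkeeping above keeps a running row maximum and denominator and rescales previously accumulated output whenever the maximum grows. A scalar, single-row sketch of that update loop (made-up scores; the P*V accumulation is reduced to a weighted sum):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<std::vector<float>> chunks{{1.0f, 2.0f}, {4.0f, 0.5f}};
  float row_max = -5e4f, row_sum = 0.f, acc_o = 0.f;   // fill_value plays the role of -inf
  for (const auto& chunk : chunks) {
    const float prev_max = row_max;
    for (float s : chunk) row_max = std::max(row_max, s);
    const float scores_scale = std::exp(prev_max - row_max);  // rescale earlier partial results
    row_sum *= scores_scale;
    acc_o *= scores_scale;
    for (float s : chunk) {
      const float p = std::exp(s - row_max);
      row_sum += p;
      acc_o += p * s;                                  // stand-in for the P*V accumulation
    }
  }
  std::printf("softmax-weighted mean = %f\n", acc_o / row_sum);
  return 0;
}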

View File

@@ -0,0 +1,232 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "cute/tensor.hpp"
#include "mla_hopper.cuh"
#include <iostream>
#include <string>
#include <sstream>
#include "batch_mla_with_paged_kv_cache.h"
#include "env.h"
using namespace cute;
using namespace mla_attn;
using namespace std;
template <typename T>
struct cascade_type_traits {
using type = T;
using cutlass_type = T;
};
template <>
struct cascade_type_traits<phi::dtype::bfloat16> {
using type = __nv_bfloat16;
  using cutlass_type = cutlass::bfloat16_t;
};
template <>
struct cascade_type_traits<phi::dtype::float16> {
using type = half;
using cutlass_type = cutlass::half_t;
};
template <>
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
using type = __nv_fp8_e4m3;
using cutlass_type = cutlass::float_e4m3_t;
};
template <typename T>
void BatchMLAWithPagedKVCacheKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out) {
using NV_TYPE = typename cascade_type_traits<T>::type;
using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto bsz = meta_data.batch_size;
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);
int q_head_dim = meta_data.head_dims;
int k_head_dim = meta_data.head_dims;
int v_head_dim = meta_data.head_dims_v;
// int num_chunks = max_dec_len / chunk_size;
int num_chunks = div_up(max_dec_len, chunk_size);
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
O_tmp = allocator->Allocate(
phi::SizeOf(q.dtype()) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
m_tmp = allocator->Allocate(
sizeof(float) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
d_tmp = allocator->Allocate(
sizeof(float) *
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
params.m = reinterpret_cast<float*>(m_tmp->ptr());
params.d = reinterpret_cast<float*>(d_tmp->ptr());
params.block_tables = const_cast<int*>(block_tables.data<int>());
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
params.num_blocks_x_int = num_blocks_x;
params.q_stride_bsz = q_head_num * q_head_dim;
params.q_stride_head_num = q_head_dim;
params.kv_stride_block_num = block_size * k_head_dim;
params.kv_stride_block_size = k_head_dim;
params.o_stride_bsz = q_head_num * v_head_dim;
params.o_stride_head_num = v_head_dim;
params.bsz = bsz;
params.token_num = token_num;
params.max_seq_len = max_seq_len;
params.max_block_num = max_block_num;
params.max_block_num_per_seq = max_block_num_per_seq;
params.q_num_head = q_head_num;
params.qk_head_dim = q_head_dim;
params.vo_head_dim = v_head_dim;
params.block_size = block_size;
params.max_draft_token_num = draft_token_num;
params.sm_scale = softmax_scale;
params.chunk_size = chunk_size;
params.chunk_num = num_chunks;
if (q_head_dim == 576) {
BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
params, stream
);
} else {
PD_THROW("error!!! q_head_dim must be 576 !!!\n");
}
}
template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);
template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -0,0 +1,68 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/allocator.h"
#include "append_attn/utils.cuh"
template <typename T>
void BatchMLAWithPagedKVCacheKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int draft_token_num,
const bool causal,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -0,0 +1,175 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
#define ATTENTION_HOPPER_EPILOGUE_CUH_
#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename Ktraits>
struct CollectiveEpilogue {
using DTypeO = typename Ktraits::DTypeO;
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
using StrideNTMAT = cute::Shape<int32_t, _1>;
using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O
static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8 elements per 128-bit store for a 16-bit DTypeO
static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64 threads cover one output row when HEAD_DIM_VO == 512
static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
using TiledCopyOValLayout =
decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
using TiledCopyO =
decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
TiledCopyOValLayout{} // Val layout
));
struct Arguments {
DTypeO* O_ptr;
LayoutNTMAT const layout_O;
DTypeO* O_ptr_tmp;
LayoutNTMAT const layout_O_tmp;
};
// Device side kernel params
struct Params {
DTypeO* O_ptr;
LayoutNTMAT const layout_O;
DTypeO* O_ptr_tmp;
LayoutNTMAT const layout_O_tmp;
};
static Params to_underlying_arguments_ntma(Arguments const& args) {
return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {}
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
typename TiledMma>
CUTLASS_DEVICE void store(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
const int thread_idx,
const int bid,
const int bsz,
const int seq_len_now,
const int start_token_idx,
const int tile_idx,
const int kv_len,
const int chunk_size,
const int max_draft_token_num,
const int o_stride_bsz) {
const int num_chunks = cute::ceil_div(kv_len, chunk_size);
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<DTypeO>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// make sure gemm done
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
// r2s
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
// make sure r2s done
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
TiledCopyO gmem_tiled_copy_O;
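// Sequences that fit in a single chunk write directly to O; otherwise this chunk's partial result goes to O_tmp for later combination.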
auto O_ptr = num_chunks == 1
? epilogue_params.O_ptr + start_token_idx * o_stride_bsz
: epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))
// Predicated copy: each output row corresponds to a (token, head) pair, so divide by GROUP_SIZE to recover the token index and skip rows past seq_len_now.
auto predicate_fn = [&](auto coords) {
auto s_coords = tOcOGroup(_0{}, coords);
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
};
copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
}
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_EPILOGUE_CUH_

View File

@@ -0,0 +1,163 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#include <type_traits>
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
namespace mla_attn {
using namespace cute;
template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
struct alignas(16) SharedStorageQKVO {
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
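// The KV staging tiles and the output tile are never live at the same time, so they share one smem region.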
union {
alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
};
};
template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
struct AttentionKernelTraits {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
using DTypeQKAccum = float;
using DTypePVAccum = float;
using NV_TYPE = NV_TYPE_;
static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
static constexpr int GROUP_SIZE = GROUP_SIZE_;
static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
static_assert(BLOCK_SHAPE_Q % 64 == 0);
static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
static_assert(HEAD_DIM_QK % 32 == 0);
static_assert(HEAD_DIM_VO % 32 == 0);
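// Warp-specialized layout: one 128-thread producer warp group for loads plus consumer warp groups for the MMAs (two consumer warp groups when BLOCK_SHAPE_Q == 64).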
static constexpr int NUM_WARPS = 12;
static constexpr int NUM_THREADS = 384;
static constexpr int NUM_PRODUCER_THREADS = 128;
using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
static constexpr int NUM_STAGES = NUM_STAGES_;
using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
using TiledMmaQK = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
using TiledMmaPV = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutPV{}));
using TiledMmaPVSS = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutPV{}));
static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutK = decltype(tile_to_shape(
SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
using SmemLayoutVt = decltype(composition(
SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutV = decltype(tile_to_shape(
SmemLayoutAtomV{},
make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVtOneStage = decltype(composition(
SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
get<2>(TileShape_PDV{}), Int<1>{}),
Step<_2, _1, _3>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;
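// BLOCK_SHAPE_KV == 32 selects the double-buffered (two-stage) P and scale layouts used by the two-stage mainloop.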
static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
decltype(cute::get<1>(TileShape_QKD{}))>());
using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;
using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
using PipelineStateQ = typename cutlass::PipelineState<1>;
using MainloopPipeline =
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
typename cutlass::PipelineAsync<NUM_STAGES>>;
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;
using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
};
} // namespace mla_attn
#endif

View File

@@ -0,0 +1,348 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename Ktraits, bool CAUSAL>
struct CollectiveMainloop {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = float;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "HEAD_DIM_QK must be a multiple of kGmemElemsPerLoad"); // holds for HEAD_DIM_QK = 576, HEAD_DIM_VO = 512
static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // (16, 8) for 128 producer threads
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
GmemCopyAtomQ{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomQ = Layout<
Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // (16, 8) for 128 producer threads
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyQ = decltype(make_tiled_copy(
GmemCopyAtomQ{},
GmemLayoutAtomQ{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using ShapeQT = cute::Shape<int32_t, int32_t>;
using StrideQT = cute::Shape<int32_t, _1>;
using LayoutQT = cute::Layout<ShapeQT, StrideQT>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeMDT = cute::Shape<int32_t, int32_t>;
using StrideMDT = cute::Shape<int32_t, _1>;
using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)), StrideT{}
),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_QKD{}),
_1{})); // no mcast for KV
static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
static constexpr uint32_t TmaTransactionBytesQ =
static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
static constexpr uint32_t TmaTransactionBytesKV =
static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);
// Host side kernel arguments
struct Arguments {
LayoutQT layout_Q;
LayoutT layout_KV;
LayoutMDT layout_MD;
DTypeQ const* Q_ptr;
DTypeKV const* KV_ptr;
DTypeMD const* m_ptr;
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
float sm_scale;
int bsz;
int max_block_num;
int max_block_num_per_seq;
int q_stride_bsz;
int q_stride_head_num;
int kv_stride_block_num;
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
int chunk_num;
int max_draft_token_num;
};
// Device side kernel params
struct Params {
LayoutQT layout_Q;
LayoutT layout_KV;
LayoutMDT layout_MD;
DTypeQ *Q_ptr;
DTypeKV* KV_ptr;
DTypeMD* m_ptr;
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
float sm_scale;
int bsz;
int max_block_num;
int max_block_num_per_seq;
int q_stride_bsz;
int q_stride_head_num;
int kv_stride_block_num;
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
};
static Params to_underlying_arguments(Arguments const& args) {
TMA_KV tma_load_KV;
if constexpr (USE_TMA_LOAD_KV) {
Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
tma_load_KV =
make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
}
return {args.layout_Q,
args.layout_KV,
args.layout_MD,
const_cast<DTypeQ*>(args.Q_ptr),
const_cast<DTypeKV*>(args.KV_ptr),
const_cast<DTypeMD*>(args.m_ptr),
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
args.sm_scale,
args.bsz,
args.max_block_num,
args.max_block_num_per_seq,
args.q_stride_bsz,
args.q_stride_head_num,
args.kv_stride_block_num,
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
};
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
if constexpr (USE_TMA_LOAD_KV) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
}
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_q(Params const& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_write_q,
SharedStorage& shared_storage,
const int thread_idx,
const int bid) {
int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
Tensor gQ =
local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor cQ = cute::make_identity_tensor(gQ.shape());
GmemTiledCopyQ gmem_tiled_copy_q;
auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
Tensor tQcQGroup = flatten_1(tQcQ);
int valid_q_size = mainloop_params.seq_lens_this_time[bid];
auto q_predicate_fn = [&](auto coords) {
auto s_coords = tQcQGroup(_0{}, coords);
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
};
Tensor tQgQiGroup = flatten_1(tQgQ);
Tensor tQsQiGroup = flatten_1(tQsQ);
pipeline_q.producer_acquire(smem_pipe_write_q);
copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_q;
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_write_kv,
SharedStorage& shared_storage,
const int bid,
const int kv_len,
const int tile_idx) {
int thread_idx = threadIdx.x;
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
GmemTiledCopy gmem_tiled_copy_kv;
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(
make_gmem_ptr(mainloop_params.kv_block_tables),
make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq),
make_stride(mainloop_params.max_block_num_per_seq, 1)));
Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);
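// Walk this chunk's KV tiles from last to first, translating each logical tile index to a physical block via the block table and staging it into the next pipeline slot.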
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
const int block_idx = kv_block_tables(bid, kv_tile_idx);
pipeline_kv.producer_acquire(smem_pipe_write_kv);
Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
Tensor tKsKiGroup =
flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_kv;
}
}
template <typename SharedStorage>
CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_write_kv,
SharedStorage& shared_storage,
const int bid,
const int kv_len,
const int tile_idx) {
int thread_idx = threadIdx.x;
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());
// Prepare the TMA loads
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
auto [tKgK, tKsK] =
tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(
make_gmem_ptr(mainloop_params.kv_block_tables),
make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq),
make_stride(mainloop_params.max_block_num_per_seq, 1)));
int lane_predicate = cute::elect_one_sync();
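// A single elected thread issues the TMA copies; completion is signalled through the pipeline's mbarrier.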
if (lane_predicate) {
#pragma unroll 2
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
const int block_idx = kv_block_tables(bid, kv_tile_idx);
pipeline_kv.producer_acquire(smem_pipe_write_kv);
copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
++smem_pipe_write_kv;
}
}
}
};
} // namespace mla_attn
#endif // ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_

View File

@@ -0,0 +1,500 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "named_barrier.cuh"
// #define DEBUG_MLA
namespace mla_attn {
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_read_q,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_read_kv,
FrgTensorO& tOrO,
AttentionUpdater& attention_updater,
const int thread_idx,
const int bid,
const int kv_len,
const int qo_len,
const int tile_idx,
SharedStorage& shared_storage) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutP = typename Ktraits::SmemLayoutP;
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
typename Ktraits::TiledMmaQK tiled_mma_qk;
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
int warp_group_idx = cutlass::canonical_warp_group_idx();
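// Both consumer warp groups cooperate on the PV GEMM; warp group 1 additionally computes QK^T, masking and the online-softmax update, publishing P and the rescale factors through shared memory for warp group 2.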
if (warp_group_idx == 1) {
// consumer 0, compute qk
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
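// Causal limit: query token qo_idx may attend to KV positions strictly below qo_idx + 1 + (kv_len - qo_len).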
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
bool is_first_step = true;
// wait q
consumer_wait(pipeline_q, smem_pipe_read_q);
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
#pragma unroll 1
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
// wait kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// gemm qk
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
// mask
if (masking_step > 0) {
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
// update s (exp(s - m))
Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
is_first_step = false;
Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP);
cute::copy(scale_o, tScalesScale);
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
// make sure r2s all done
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
attention_updater.rescale_o(tOrO, scale_o);
// pv gemm
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV1(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV2(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
// sync WG1 WG2
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
}
// release q
pipeline_q.consumer_release(smem_pipe_read_q);
++smem_pipe_read_q;
// normalize
Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
if (chunk_num_this_seq == 1) {
// norm
cute::copy(scale_o, tScalesScale);
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
attention_updater.rescale_o(tOrO, scale_o);
}
// WG1 write m,d back to gmem
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // each warp owns 16 accumulator rows (e.g. t0 -> rows 0 and 8, t4 -> rows 1 and 9), so one thread per quad writes m/d
const int warp_idx = thread_idx / 32;
#pragma unroll
for (int w_i = 0; w_i < 2; ++w_i) {
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
if (token_idx < qo_len) {
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
}
}
}
} else if (warp_group_idx == 2) {
// consumer 1, compute pv
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
// wait kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// A: tPsP
cute::copy(tScalesScale, scale_o);
// rescale
attention_updater.rescale_o(tOrO, scale_o);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV1(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
tOrV2(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
// sync WG1 WG2
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
}
if (chunk_num_this_seq == 1) {
// norm
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
cute::copy(tScalesScale, scale_o);
attention_updater.rescale_o(tOrO, scale_o);
}
}
return;
}
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
MainloopPipelineQ pipeline_q,
PipelineStateQ& smem_pipe_read_q,
MainloopPipeline pipeline_kv,
PipelineState& smem_pipe_read_kv,
FrgTensorO& tOrO,
AttentionUpdater& attention_updater,
const int thread_idx,
const int bid,
const int kv_len,
const int qo_len,
const int tile_idx,
SharedStorage& shared_storage) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeMD = typename Ktraits::DTypeO; // note: output dtype (e.g. bf16)
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using IdType = typename Ktraits::IdType;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
using SmemLayoutK = typename Ktraits::SmemLayoutK;
using SmemLayoutV = typename Ktraits::SmemLayoutV;
using SmemLayoutP = typename Ktraits::SmemLayoutP;
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
typename Ktraits::TiledMmaQK tiled_mma_qk;
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 1) {
// consumer 0, compute qk
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
// wait q
consumer_wait(pipeline_q, smem_pipe_read_q);
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
// wait k
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// first qk gemm
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
// mask
{
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
--kv_tile_idx;
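// Software-pipelined steady state: issue the next tile's QK GEMM without waiting, overlap it with the previous tile's PV GEMM, then mask, update the online softmax, and stage P/scales for the other consumer warp group.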
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
++smem_pipe_read_kv;
// wait next kv
consumer_wait(pipeline_kv, smem_pipe_read_kv);
// gemm next qk
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
tSrS);
attention_updater.rescale_o(tOrO);
// last pv gemm
if (smem_pipe_read_kv_cur.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv_cur.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv_cur.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
// wait cur qk gemm
warpgroup_wait<1>();
// mask p
if (masking_step > 0) {
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
if constexpr (!CAUSAL) { // Just masking based on col
if (kv_idx >= kv_len) {
tSrS(i) = AttentionUpdater::fill_value;
}
} else {
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
tSrS(i) = AttentionUpdater::fill_value;
}
}
}
}
// update s (exp(s - m))
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
// gather qk gemm res
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
// r2s fence wgmma
cutlass::arch::fence_view_async_shared();
// make sure tSrS r2s done
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// wait last pv gemm
warpgroup_wait<0>();
// release last kv
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
}
// release q
pipeline_q.consumer_release(smem_pipe_read_q);
++smem_pipe_read_q;
// compute last pv
attention_updater.rescale_o(tOrO);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
scale_o = attention_updater.finalize(tSrS);
warpgroup_wait<0>();
// release last kv
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
if (chunk_num_this_seq == 1) {
// norm
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
attention_updater.rescale_o(tOrO);
}
// WG1 write m,d back to gmem
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // each warp owns 16 accumulator rows (e.g. t0 -> rows 0 and 8, t4 -> rows 1 and 9), so one thread per quad writes m/d
const int warp_idx = thread_idx / 32;
#pragma unroll
for (int w_i = 0; w_i < 2; ++w_i) {
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
if (token_idx < qo_len) {
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
}
}
}
} else if (warp_group_idx == 2) {
// consumer 1, compute pv
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
consumer_wait(pipeline_kv, smem_pipe_read_kv);
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
// A: tPsP
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
// rescale
attention_updater.rescale_o(tOrO, scale_o);
if (smem_pipe_read_kv.index() == 0) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV1(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 1) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV2(_, _, _, _0{}), tOrO);
} else if (smem_pipe_read_kv.index() == 2) {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV3(_, _, _, _0{}), tOrO);
} else {
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
tOrV4(_, _, _, _0{}), tOrO);
}
pipeline_kv.consumer_release(smem_pipe_read_kv);
++smem_pipe_read_kv;
}
if (chunk_num_this_seq == 1) {
// norm
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
attention_updater.rescale_o(tOrO, scale_o);
}
}
return;
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_

View File

@@ -0,0 +1,575 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "attention_updater.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA
namespace mla_attn {
using namespace cute;
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct Params {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *padding_offsets;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
int bsz;
int token_num;
int max_seq_len;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;
float sm_scale;
};
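// Note (editorial, not enforced by this struct): the stride members are assumed to be
// element strides matching the logical layouts documented above, e.g. for a contiguous
// row-major Q of shape [token_num, head_num, dim_head] one would expect
//   q_stride_bsz      == q_num_head * qk_head_dim
//   q_stride_head_num == qk_head_dim
// The host-side launcher supplies the actual values.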
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
}
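// Illustrative expansion of the dispatch macro above: the launcher at the bottom of this
// file calls
//   DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
//       BatchMLAWithPagedKVCacheKernelTraitsDispatched<...>(params, stream);)
// so for params.q_num_head == 16 the body is compiled with
//   constexpr size_t GROUP_SIZE = 16;
// and any unsupported head count falls through to PD_THROW / cudaErrorNotSupported.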
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeO = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
using PipelineParamsQ = typename MainloopPipelineQ::Params;
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
extern __shared__ char shared_memory[];
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Obtain the warp-group thread index and configure per-role pipeline parameters
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
if constexpr (use_tma_load_kv) {
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NUM_MMA_THREADS;
} else {
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
}
PipelineParamsQ pipeline_params_q;
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // only the QK warp group consumes Q
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
MainloopPipeline pipeline_kv = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
}
}();
__syncthreads();
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
if (warp_group_idx == 0) {
// producer
if constexpr(USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_dealloc<72>();
}
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
// consumer
if constexpr(USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_alloc<216>();
}
PipelineStateQ smem_pipe_read_q;
PipelineState smem_pipe_read_kv;
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if constexpr (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if constexpr (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
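// Scheduling summary for the kernel above: warp group 0 is the producer that stages Q and
// the paged KV tiles through the shared-memory pipelines, while the consumer warp groups
// run the MMAs (warp group 2 computes PV in mainloop_mma.cuh, leaving QK/softmax to warp
// group 1). With USE_FIXED_BLOCK the grid is sized to the SM count and each CTA strides
// over work items (batch_ids / tile_ids_per_batch); otherwise one CTA is launched per work
// item and handles exactly blockIdx.x.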
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using NV_TYPE = typename KernelTraits::NV_TYPE;
using CollectiveMainloop =
CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
// Get the ptr to kernel function.
// Forward the register-(de)alloc and fixed-block switches so the kernel instantiation
// matches the launch configuration chosen below.
auto kernel =
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132, USE_REG_EALLOC, USE_FIXED_BLOCK>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
int device;
cudaGetDevice(&device);
int multiprocessor_count;
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
int act_blocks_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims = {gridx, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
mainloop_params, epilogue_params
);
if (params.chunk_num > 1) {
constexpr int vec_size = 16 / sizeof(DTypeO);
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // cap gridDim.x at the SM count; token_num alone can be very large (e.g. 128k)
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.padding_offsets,
reinterpret_cast<NV_TYPE*>(params.O),
params.max_seq_len,
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
}
return cudaSuccess;
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
} else {
return cudaErrorNotSupported;
}
return cudaSuccess;
}
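// Minimal host-side usage sketch (hypothetical types and values; the real caller fills
// Params from the framework's tensors and scheduling metadata):
//   Params<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, int> params{};
//   // ... populate pointers, strides, bsz, q_num_head, qk_head_dim = 576,
//   //     vo_head_dim = 512, chunk_size / chunk_num, batch_ids, tile_ids_per_batch ...
//   cudaError_t status =
//       BatchMLAWithPagedKVCacheDispatched<576, 512, __nv_bfloat16>(params, stream);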
} // namespace mla_attn
#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
#include <cuda_runtime.h>
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
namespace mla_attn {
enum class NamedBarriers {
kQueryEmpty = 0,
kValueEmpty = 1,
kWarpSchedulerWG1 = 2,
kWarpSchedulerWG2 = 3,
kWarpSchedulerWG3 = 4,
kPrefetchIndices = 5,
kOdone = 6,
kWG1WG2Sync = 7,
kWG0WG1WG2Sync = 8,
kWG1WG2LastSync = 9,
};
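// These ids are consumed through cutlass::arch::NamedBarrier, e.g. the prefill kernel
// synchronizes all three warp groups with
//   cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
//       static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// The ids must remain distinct; Hopper exposes only 16 hardware named barriers per CTA
// (some of which CUTLASS may reserve for its own use).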
} // namespace mla_attn
#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_

View File

@@ -0,0 +1,351 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ATTENTION_HOPPER_UTILS_CUH_
#define ATTENTION_HOPPER_UTILS_CUH_
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include <assert.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <stdlib.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cuda_runtime.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cmath>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "cutlass/fast_math.h"
namespace mla_attn {
using namespace cute;
template <typename TensorT>
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
Tensor tensor_flatten = cute::flatten(tensor);
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
}
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
int64_t h_stride) {
return make_layout(make_shape(nnz, head_dim, num_heads),
make_stride(n_stride, cute::_1{}, h_stride));
}
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
int head_idx, int offset, int seq_len) {
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
make_coord(offset, _0{}));
auto g_sequence =
make_tensor(g_offset.data(),
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
int head_idx, int offset, int seq_len) {
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
cute::make_shape(shape<0>(m_tensor))));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
return g_tensor;
}
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
// MMA_N))
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
};
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
// MMA_N))
template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
make_layout(get<2, 1>(l), get<2>(acc_layout)));
};
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
TensorC& tCrC) {
constexpr bool Is_RS =
!cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
}
warpgroup_fence_operand(tCrC);
warpgroup_arrive();
if constexpr (init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// Unrolled K loop kept for symmetry with the init path; accumulates into tCrC.
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
}
}
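// Calling pattern used by the mainloop (see mainloop_mma.cuh): the first K tile of a score
// computation is issued with init=true to zero the accumulator, later tiles with
// init=false, e.g.
//   gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP, tOrV, tOrO);
// Identifiers in this example are illustrative; only the init/wg_wait switches and the
// (TiledMma, A, B, C) argument order are defined by this helper.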
#define HOSTDEVICE __host__ __device__
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};
template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
const AlignedVector<T, Size>* addr_vec =
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
*vec = *addr_vec;
}
template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
AlignedVector<T, Size>* addr_vec =
reinterpret_cast<AlignedVector<T, Size>*>(addr);
*addr_vec = vec;
}
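// Usage sketch for the vectorized helpers above (this is how the merge kernel below moves
// head_dim-sized rows, 16 bytes at a time):
//   AlignedVector<half, 8> v;
//   Load<half, 8>(src + offset, &v);   // one aligned 16-byte load
//   Store<half, 8>(v, dst + offset);   // one aligned 16-byte store
// src/dst must be 16-byte aligned at the accessed offsets.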
template <size_t vec_size, typename T>
struct prefill_softmax_state_t {
AlignedVector<T, vec_size> o;
float m;
float d;
__device__ __forceinline__ void init() {
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&o) + i) = make_half2(0, 0);
}
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
}
}
d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
m = -3.38953e38f;
}
}
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
const float other_m,
const float other_d) {
float m_prev = m, d_prev = d;
m = max(m_prev, other_m);
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
d = d_prev * scale1 + other_d * scale2;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
}
}
__device__ __forceinline__ void normalize() {
const T d_t = static_cast<T>(d);
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
o[i] /= d_t;
}
}
};
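// merge() above is the usual online-softmax combination of two partial states
// (m_i, d_i, o_i), restated for reference:
//   m = max(m_1, m_2)
//   d = d_1 * exp(m_1 - m) + d_2 * exp(m_2 - m)
//   o = o_1 * exp(m_1 - m) + o_2 * exp(m_2 - m)
// with a single o /= d in normalize() at the end, so chunk results can be folded in any
// order before normalizing.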
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int * __restrict__ padding_offsets,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,
const int max_draft_token_num) {
const int vid = threadIdx.x, ty = threadIdx.y;
const int hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t ori_token_id = qid + padding_offsets[qid];
const uint32_t bid = ori_token_id / max_seq_len;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = ori_token_id % max_seq_len;
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
// no merge needed
continue;
}
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
float d = 1.f;
if constexpr (std::is_same<T, half>::value) {
m = -5e4f;
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset;
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
d = d * scale1 + d_now * scale2;
#pragma unroll
for (int j = 0; j < vec_size; j++) {
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
}
}
// store this thread-row's (ty) partial result to shared memory
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
md_smem[2 * ty] = m;
md_smem[2 * ty + 1] = d;
__syncthreads();
if (ty == 0) {
// thread row 0 merges the bdy partial results
prefill_softmax_state_t<vec_size, T> st;
st.init();
#pragma unroll
for (int i = 0; i < bdy; i++) {
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
st.merge(load_vec, m_tmp, d_tmp);
}
st.normalize();
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
__syncthreads();
}
}
} // namespace mla_attn
#endif // ATTENTION_HOPPER_UTILS_CUH_
