# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""generate multiquery_attention_c8_kernel template instantiation."""

from pathlib import Path

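# Output directory for the generated .cu instantiation files.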
TEMPLATE_DIR = Path("gpu_ops/append_attn/template_instantiation/autogen")
TEMPLATE_DIR.mkdir(exist_ok=True)

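# Dispatch parameters covered by the generated instantiations; one explicit
# template instantiation is emitted for every combination of these values.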
DISPATCH_PARAMS = {
    "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
    "HEAD_DIM": [128],
    "BLOCK_SIZE": [64],
    "CAUSAL": [0, 1],
    "BLOCK_SHAPE_Q": [16, 32, 64, 128],
    "ENABLE_PREFILL": [0, 1],
    "IsFP8": [0, 1],
    "IsDynamicC8": [0, 1],
}

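# (input dtype, output dtype, tag used in the generated file names).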
DATA_TYPE_COMBINATIONS = [
    ("paddle::float16", "paddle::float16", "float16_float16"),
    ("paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"),
    ("paddle::float16", "int8_t", "float16_int8"),
    ("paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"),
    ("paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"),
    ("paddle::bfloat16", "int8_t", "bfloat16_int8"),
]

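# Cap on instantiations per generated file; larger sets are split across
# multiple "_part_XX.cu" files.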
MAX_INSTANCES_PER_FILE = 60


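# Choose the num_warp_q template argument from the query tile size:
# 1 warp for BLOCK_SHAPE_Q <= 32, otherwise 4.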
def get_num_warp_q(block_shape_q):
    if block_shape_q <= 32:
        return 1
    else:
        return 4


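# Shared header prepended to every generated file.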
def generate_file_header():
    return """// Generated by autogen_template_instantiation.py - Do not edit.

#pragma once

#include "../../multiquery_attention_c8_impl.cuh"
"""


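# Render one explicit instantiation of MultiQueryAppendC8Attention for the given
# dtypes and dispatch parameters, e.g. (float16 input, int8_t output, GROUP_SIZE=8,
# CAUSAL=1, BLOCK_SHAPE_Q=64, ENABLE_PREFILL=1, IsFP8=0, IsDynamicC8=0):
#   template void MultiQueryAppendC8Attention<paddle::float16, 8, 128, 64, 1, 64, 4, int8_t, 1, 0, 0>(...);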
def generate_template_instantiation(t_in, t_out, params):
    num_warp_q = get_num_warp_q(params["BLOCK_SHAPE_Q"])
    template_args = f"<{t_in}, {params['GROUP_SIZE']}, {params['HEAD_DIM']}, {params['BLOCK_SIZE']}, {params['CAUSAL']}, {params['BLOCK_SHAPE_Q']}, {num_warp_q}, {t_out}, {params['ENABLE_PREFILL']}, {params['IsFP8']}, {params['IsDynamicC8']}>"

    return f"""
template void MultiQueryAppendC8Attention{template_args}(
    const AppendAttnMetaData &meta_data,
    const paddle::Tensor &qkv,
    const paddle::Tensor &cache_k,
    const paddle::Tensor &cache_v,
    const paddle::optional<paddle::Tensor> &attn_mask,
    const paddle::Tensor &cache_k_scale,
    const paddle::Tensor &cache_v_scale,
    const paddle::optional<paddle::Tensor> &shift_bias,
    const paddle::optional<paddle::Tensor> &smooth_weight,
    const paddle::Tensor &seq_lens_q,
    const paddle::Tensor &seq_lens_kv,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &batch_id_per_token,
    const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_table,
    const paddle::Tensor &batch_ids,
    const paddle::Tensor &tile_ids_per_batch,
    const int num_blocks_x_cpu,
    const int max_seq_len,
    const int max_dec_len,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const int max_partition_size,
    const int encoder_max_partition_size,
    const int speculate_max_draft_token_num,
    const bool is_decoder,
    cudaStream_t &stream,
    paddle::Tensor *out);

"""


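# Build the full cartesian product of DISPATCH_PARAMS for one dtype pair.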
def generate_combinations_for_type(t_in, t_out):
    combinations = []
    for group_size in DISPATCH_PARAMS["GROUP_SIZE"]:
        for head_dim in DISPATCH_PARAMS["HEAD_DIM"]:
            for block_size in DISPATCH_PARAMS["BLOCK_SIZE"]:
                for causal in DISPATCH_PARAMS["CAUSAL"]:
                    for block_shape_q in DISPATCH_PARAMS["BLOCK_SHAPE_Q"]:
                        for enable_prefill in DISPATCH_PARAMS["ENABLE_PREFILL"]:
                            for is_fp8 in DISPATCH_PARAMS["IsFP8"]:
                                for is_dynamic_c8 in DISPATCH_PARAMS["IsDynamicC8"]:
                                    params = {
                                        "GROUP_SIZE": group_size,
                                        "HEAD_DIM": head_dim,
                                        "BLOCK_SIZE": block_size,
                                        "CAUSAL": causal,
                                        "BLOCK_SHAPE_Q": block_shape_q,
                                        "ENABLE_PREFILL": enable_prefill,
                                        "IsFP8": is_fp8,
                                        "IsDynamicC8": is_dynamic_c8,
                                    }
                                    combinations.append(params)

    return combinations


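# Split the combination list into chunks of at most max_per_file entries.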
def split_combinations(combinations, max_per_file):
    chunks = []
    for i in range(0, len(combinations), max_per_file):
        chunk = combinations[i : i + max_per_file]
        chunks.append(chunk)
    return chunks


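# Assemble one generated file: the shared header followed by one instantiation
# per parameter combination.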
def generate_file_content(t_in, t_out, t_out_name, file_index, combinations):
    content = generate_file_header()
    for params in combinations:
        content += generate_template_instantiation(t_in, t_out, params)

    return content


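# Write multiquery_attention_c8_<tag>_part_<NN>.cu files under TEMPLATE_DIR for
# every dtype combination.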
def main():
    for t_in, t_out, t_out_name in DATA_TYPE_COMBINATIONS:
        combinations = generate_combinations_for_type(t_in, t_out)
        if combinations:
            chunks = split_combinations(combinations, MAX_INSTANCES_PER_FILE)
            for i, chunk in enumerate(chunks):
                filename = f"multiquery_attention_c8_{t_out_name}_part_{i:02d}.cu"
                filepath = TEMPLATE_DIR / filename
                content = generate_file_content(t_in, t_out, t_out_name, i, chunk)
                with open(filepath, "w", encoding="utf-8") as f:
                    f.write(content)


if __name__ == "__main__":
    main()