mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-12 20:11:20 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -19,8 +19,7 @@ import math
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
|
||||
get_smem_config
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import get_smem_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console_handler = logging.StreamHandler()
|
||||
@@ -40,7 +39,10 @@ def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
|
||||
[hidden_size, intermediate_size * 2],
|
||||
[intermediate_size, hidden_size],
|
||||
[hidden_size, hidden_size],
|
||||
[hidden_size, (num_attention_heads + num_key_value_heads * 2) * head_dim],
|
||||
[
|
||||
hidden_size,
|
||||
(num_attention_heads + num_key_value_heads * 2) * head_dim,
|
||||
],
|
||||
]
|
||||
grouped_gemm_contiguous_kn_pairs = [
|
||||
# Moe grouped gemm contiguous
|
||||
@@ -53,7 +55,11 @@ def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
|
||||
[moe_intermediate_size, hidden_size],
|
||||
]
|
||||
|
||||
return gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs
|
||||
return (
|
||||
gemm_kn_pairs,
|
||||
grouped_gemm_contiguous_kn_pairs,
|
||||
grouped_gemm_masked_kn_pairs,
|
||||
)
|
||||
|
||||
|
||||
def generate_json(
|
||||
@@ -79,9 +85,7 @@ def generate_json(
|
||||
NUM_STAGES = [8, 7, 6, 5, 4, 3]
|
||||
for num_stages in NUM_STAGES:
|
||||
for kn_pair in kn_pairs:
|
||||
smem_config = get_smem_config(
|
||||
num_stages, kn_pair[0], block_m, block_n
|
||||
)
|
||||
smem_config = get_smem_config(num_stages, kn_pair[0], block_m, block_n)
|
||||
for tma_multicast_config in TMA_MULTICAST_CONFIGS:
|
||||
cfg = {
|
||||
"N": kn_pair[1],
|
||||
@@ -107,9 +111,11 @@ def main(args):
|
||||
with open(os.path.join(args.model, "config.json"), "r") as f:
|
||||
model_cfg = json.load(f)
|
||||
|
||||
gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs = (
|
||||
generate_kn_pairs(model_cfg)
|
||||
)
|
||||
(
|
||||
gemm_kn_pairs,
|
||||
grouped_gemm_contiguous_kn_pairs,
|
||||
grouped_gemm_masked_kn_pairs,
|
||||
) = generate_kn_pairs(model_cfg)
|
||||
num_gemm = generate_json(
|
||||
gemm_kn_pairs,
|
||||
model_cfg["moe_num_experts"],
|
||||
@@ -129,9 +135,7 @@ def main(args):
|
||||
)
|
||||
logger.info(f"Configurations generated and saved to {args.output}")
|
||||
logger.info(f"Generated {num_gemm} gemm configuration.")
|
||||
logger.info(
|
||||
f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration."
|
||||
)
|
||||
logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
|
||||
logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user