Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Bugs] Fix DeepGEMM pre-compile tools. (#3351)
Fix some cache-miss problems. Add README.md.
@@ -17,7 +17,7 @@ import json
 import logging
 import math
 import os
-from typing import Tuple
+from typing import List, Tuple

 from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import get_smem_config

@@ -27,33 +27,84 @@ logger.addHandler(console_handler)
 logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))


-def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
+def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
     hidden_size = model_cfg["hidden_size"]
     intermediate_size = model_cfg["intermediate_size"]
     moe_intermediate_size = model_cfg["moe_intermediate_size"]
     num_attention_heads = model_cfg["num_attention_heads"]
     num_key_value_heads = model_cfg["num_key_value_heads"]
     head_dim = int(hidden_size / num_attention_heads)
-    gemm_kn_pairs = [
+    tp_size = args.tensor_parallel_size
+    ep_size = args.expert_parallel_size
+    has_shared_experts = args.has_shared_experts.lower() == "true"
+
+    gemm_kn_pairs = []
+    grouped_gemm_contiguous_kn_pairs = []
+    grouped_gemm_masked_kn_pairs = []
+    if tp_size > 1 and ep_size == 1:
+        logger.debug("Generating kn pairs for tensor parallel.")
         # Dense normal gemm
-        [hidden_size, intermediate_size * 2],
-        [intermediate_size, hidden_size],
-        [hidden_size, hidden_size],
-        [
-            hidden_size,
-            (num_attention_heads + num_key_value_heads * 2) * head_dim,
-        ],
-    ]
-    grouped_gemm_contiguous_kn_pairs = [
+        gemm_kn_pairs.extend(
+            [
+                [int(intermediate_size / tp_size), hidden_size],
+                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
+                [hidden_size, int(intermediate_size * 2 / tp_size)],
+                [int(hidden_size / tp_size), hidden_size],
+            ]
+        )
+
         # Moe grouped gemm contiguous
-        [hidden_size, moe_intermediate_size * 2],
-        [moe_intermediate_size, hidden_size],
-    ]
-    grouped_gemm_masked_kn_pairs = [
+        grouped_gemm_contiguous_kn_pairs.extend(
+            [
+                [int(moe_intermediate_size / tp_size), hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2 / tp_size)],
+            ]
+        )
+        if has_shared_experts:
+            logger.debug("Generating kn pairs for models with shared experts.")
+            gemm_kn_pairs.extend(
+                [
+                    [hidden_size, int(moe_intermediate_size * 4 / tp_size)],
+                    [int(moe_intermediate_size * 2 / tp_size), hidden_size],
+                ]
+            )
+    elif tp_size == 1 and ep_size > 1:
+        logger.debug("Generating kn pairs for expert parallel.")
+        # Dense normal gemm
+        gemm_kn_pairs.extend(
+            [
+                [intermediate_size, hidden_size],
+                [hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
+                [hidden_size, int(intermediate_size * 2)],
+                [hidden_size, hidden_size],
+            ]
+        )
+        # Moe grouped gemm contiguous
+        grouped_gemm_contiguous_kn_pairs.extend(
+            [
+                [moe_intermediate_size, hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2)],
+            ]
+        )
         # Moe grouped gemm masked
-        [hidden_size, moe_intermediate_size * 2],
-        [moe_intermediate_size, hidden_size],
-    ]
+        grouped_gemm_masked_kn_pairs.extend(
+            [
+                [moe_intermediate_size, hidden_size],
+                [hidden_size, int(moe_intermediate_size * 2)],
+            ]
+        )
+        if has_shared_experts:
+            logger.debug("Generating kn pairs for models with shared experts.")
+            gemm_kn_pairs.extend(
+                [
+                    [hidden_size, int(moe_intermediate_size * 4)],
+                    [int(moe_intermediate_size * 2), hidden_size],
+                ]
+            )
+    elif tp_size > 1 and ep_size > 1:
+        raise ValueError("Not supported to enable EP and TP at the same time for now.")
+    else:
+        raise ValueError("Please check the tensor parallel size and expert parallel size.")

     return (
         gemm_kn_pairs,
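
For illustration only (not part of the commit), here is a minimal sketch of how the reworked generate_kn_pairs branches on the parallelism settings. The Namespace fields mirror the CLI flags added at the end of this diff; the model_cfg values are made-up placeholders rather than a real config.json:

from argparse import Namespace

# Hypothetical inputs; in the tool they come from argparse and <model>/config.json.
args = Namespace(
    tensor_parallel_size=2,     # --tensor-parallel-size / --tp
    expert_parallel_size=1,     # --expert-parallel-size / --ep
    has_shared_experts="False",
)
model_cfg = {
    "hidden_size": 4096,
    "intermediate_size": 12288,
    "moe_intermediate_size": 1536,
    "num_attention_heads": 32,
    "num_key_value_heads": 4,
}

# Assumes generate_kn_pairs from this file is in scope (e.g. imported from the module).
gemm, contiguous, masked = generate_kn_pairs(args, model_cfg)
# TP-only branch: dense shapes are sharded by tp_size, e.g.
# [int(intermediate_size / 2), hidden_size] -> [6144, 4096],
# and grouped_gemm_masked stays empty -- which is why main() below only calls
# generate_json for lists with len(...) > 0.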
@@ -78,7 +129,8 @@ def generate_json(
     counter = 0
     with open(output_path, "a+", encoding="utf-8") as f:
         for block_m in BLOCK_MS:
-            for block_n in BLOCK_NS:
+            # NOTES: the block sizes can not be too large, so at least one dim less than 128
+            for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, BLOCK_NS):
                 if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
                     NUM_STAGES = [4, 3]
                 else:
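
As an aside, a tiny sketch of what the new filter admits; the BLOCK_MS and BLOCK_NS candidate lists below are assumed, illustrative values, not the ones defined elsewhere in this file:

# Hypothetical candidate tile sizes, only for illustrating the filter predicate.
BLOCK_MS = [64, 128, 256]
BLOCK_NS = list(range(16, 257, 8))
kept = [(m, n) for m in BLOCK_MS for n in BLOCK_NS if m <= 128 or n <= 128]
# (256, 160) is dropped while (256, 96) and (128, 160) are kept:
# at least one of the two block dimensions must stay at or below 128.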
@@ -110,33 +162,43 @@ def generate_json(
 def main(args):
     with open(os.path.join(args.model, "config.json"), "r") as f:
         model_cfg = json.load(f)
-
+    logger.debug(
+        f"TP Size: {args.tensor_parallel_size}, "
+        f"EP Size: {args.expert_parallel_size}, "
+        f"has shared experts: {args.has_shared_experts}"
+    )
+    logger.info(f"Configurations generated and saved to {args.output}")
     (
         gemm_kn_pairs,
         grouped_gemm_contiguous_kn_pairs,
         grouped_gemm_masked_kn_pairs,
-    ) = generate_kn_pairs(model_cfg)
-    num_gemm = generate_json(
-        gemm_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-    )
-    num_grouped_contiguous = generate_json(
-        grouped_gemm_contiguous_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-        is_grouped_contiguous=True,
-    )
-    num_grouped_masked = generate_json(
-        grouped_gemm_masked_kn_pairs,
-        model_cfg["moe_num_experts"],
-        args.output,
-        is_grouped_masked=True,
-    )
-    logger.info(f"Configurations generated and saved to {args.output}")
-    logger.info(f"Generated {num_gemm} gemm configuration.")
-    logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
-    logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
+    ) = generate_kn_pairs(args, model_cfg)
+    logger.debug(f"GEMM KN pairs: {gemm_kn_pairs}")
+    logger.debug(f"Grouped GEMM Contiguous KN pairs: {grouped_gemm_contiguous_kn_pairs}")
+    logger.debug(f"Grouped GEMM Masked KN pairs: {grouped_gemm_masked_kn_pairs}")
+    if len(gemm_kn_pairs) > 0:
+        num_gemm = generate_json(
+            gemm_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+        )
+        logger.info(f"Generated {num_gemm} gemm configuration.")
+    if len(grouped_gemm_contiguous_kn_pairs) > 0:
+        num_grouped_contiguous = generate_json(
+            grouped_gemm_contiguous_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+            is_grouped_contiguous=True,
+        )
+        logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
+    if len(grouped_gemm_masked_kn_pairs) > 0:
+        num_grouped_masked = generate_json(
+            grouped_gemm_masked_kn_pairs,
+            model_cfg["moe_num_experts"],
+            args.output,
+            is_grouped_masked=True,
+        )
+        logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")


 if __name__ == "__main__":
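
The logger.debug diagnostics added to main() above are hidden at the default log level; the module reads its level from the PRE_COMPILE_LOG_LEVEL environment variable (see the logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO")) line earlier in this diff). A minimal sketch of enabling them, assuming the variable is set before the script starts:

import os

# Hypothetical: set this (or export it in the shell) before launching the tool, so the
# module-level logger.setLevel(...) call picks up "DEBUG" instead of the "INFO" default.
os.environ["PRE_COMPILE_LOG_LEVEL"] = "DEBUG"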
@@ -146,6 +208,23 @@ if __name__ == "__main__":
         type=str,
         required=True,
     )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        "--tp",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--expert-parallel-size",
+        "--ep",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--has-shared-experts",
+        type=str,
+        default="False",
+    )
     parser.add_argument(
         "--output",
         type=str,
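
For reference, a short sketch of how the new flags are parsed. The parser construction and the "--model"/"--output" details sit partly outside this diff, so they are reconstructed here as assumptions, and the paths are placeholders:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)  # flag name inferred from args.model in main()
parser.add_argument("--tensor-parallel-size", "--tp", type=int, default=1)
parser.add_argument("--expert-parallel-size", "--ep", type=int, default=1)
parser.add_argument("--has-shared-experts", type=str, default="False")
parser.add_argument("--output", type=str)  # remaining --output options are truncated in this view

args = parser.parse_args(
    ["--model", "/path/to/model", "--tp", "8", "--has-shared-experts", "True", "--output", "./deep_gemm_configs"]
)
# argparse maps the dashed flags to underscored attributes.
assert args.tensor_parallel_size == 8 and args.expert_parallel_size == 1
# The switch is string-typed: generate_kn_pairs enables the shared-expert shapes
# only when the value lower-cases to "true".
assert args.has_shared_experts.lower() == "true"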