[Bugs] Fix DeepGEMM pre-compile tools. (#3351)

Fix some miss cache problems.
Add README.md.
This commit is contained in:
GoldPancake
2025-08-15 14:37:49 +08:00
committed by GitHub
parent d4e3a20300
commit 4bd6a9fa7d
4 changed files with 198 additions and 52 deletions

View File

@@ -0,0 +1,61 @@
# DeepGEMM Pre-compilation Tool
This tool provides pre-compilation functionality for DeepGEMM kernels to optimize performance.
## Usage
### 1. Using Shell Script (Recommended)
```bash
bash pre_compile.sh \
[MODEL_PATH] \
[TP_SIZE] \
[EP_SIZE] \
[HAS_SHARED_EXPERTS] \
[OUTPUT_FILE]
```
The script will:
1. Generate configurations
2. Pre-compile all kernels
### 2. Alternative: Manual Steps
If you need more control, you can run the steps manually:
#### Generate Configuration
```bash
python generate_config.py \
--model /path/to/model \
--tensor-parallel-size [TP_SIZE] \
--expert-parallel-size [EP_SIZE] \
--has-shared-experts [True/False] \
--output [CONFIG_FILE]
```
Arguments:
- `--model`: Path to model directory containing config.json
- `--tensor-parallel-size`: Tensor parallel size (default: 1)
- `--expert-parallel-size`: Expert parallel size (default: 8)
- `--has-shared-experts`: Whether model has shared experts (default: False)
- `--output`: Output config file path (default: ./deep_gemm_pre_compile_config.jsonl)
#### Pre-compile Kernels
```bash
python pre_compile.py \
--config-file [CONFIG_FILE] \
--expert-parallel-size [EP_SIZE] \
--num-threads [NUM_THREADS]
```
Arguments:
- `--config-file`: Path to config file generated in step 1
- `--expert-parallel-size`: Expert parallel size (must match step 1)
- `--num-threads`: Number of compilation threads (default: CPU cores)
## Environment Variables
- `PRE_COMPILE_LOG_LEVEL`: Set log level (DEBUG/INFO/WARNING/ERROR)
- `DG_CACHE_DIR`: Cache directory for compiled kernels (default: ./deep_gemm_cache)
## Notes
- For best performance, set `--num-threads` to the number of available CPU cores
- The compilation process may take significant time depending on configuration size
- Compiled kernels will be cached in `DG_CACHE_DIR`

View File

@@ -17,7 +17,7 @@ import json
import logging
import math
import os
from typing import Tuple
from typing import List, Tuple
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import get_smem_config
@@ -27,33 +27,84 @@ logger.addHandler(console_handler)
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))
def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
def generate_kn_pairs(args, model_cfg: dict) -> Tuple[List, List, List]:
hidden_size = model_cfg["hidden_size"]
intermediate_size = model_cfg["intermediate_size"]
moe_intermediate_size = model_cfg["moe_intermediate_size"]
num_attention_heads = model_cfg["num_attention_heads"]
num_key_value_heads = model_cfg["num_key_value_heads"]
head_dim = int(hidden_size / num_attention_heads)
gemm_kn_pairs = [
tp_size = args.tensor_parallel_size
ep_size = args.expert_parallel_size
has_shared_experts = args.has_shared_experts.lower() == "true"
gemm_kn_pairs = []
grouped_gemm_contiguous_kn_pairs = []
grouped_gemm_masked_kn_pairs = []
if tp_size > 1 and ep_size == 1:
logger.debug("Generating kn pairs for tensor parallel.")
# Dense normal gemm
[hidden_size, intermediate_size * 2],
[intermediate_size, hidden_size],
[hidden_size, hidden_size],
gemm_kn_pairs.extend(
[
hidden_size,
(num_attention_heads + num_key_value_heads * 2) * head_dim,
],
[int(intermediate_size / tp_size), hidden_size],
[hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2) / tp_size)],
[hidden_size, int(intermediate_size * 2 / tp_size)],
[int(hidden_size / tp_size), hidden_size],
]
grouped_gemm_contiguous_kn_pairs = [
)
# Moe grouped gemm contiguous
[hidden_size, moe_intermediate_size * 2],
[moe_intermediate_size, hidden_size],
grouped_gemm_contiguous_kn_pairs.extend(
[
[int(moe_intermediate_size / tp_size), hidden_size],
[hidden_size, int(moe_intermediate_size * 2 / tp_size)],
]
grouped_gemm_masked_kn_pairs = [
)
if has_shared_experts:
logger.debug("Generating kn pairs for models with shared experts.")
gemm_kn_pairs.extend(
[
[hidden_size, int(moe_intermediate_size * 4 / tp_size)],
[int(moe_intermediate_size * 2 / tp_size), hidden_size],
]
)
elif tp_size == 1 and ep_size > 1:
logger.debug("Generating kn pairs for expert parallel.")
# Dense normal gemm
gemm_kn_pairs.extend(
[
[intermediate_size, hidden_size],
[hidden_size, int(head_dim * (num_attention_heads + num_key_value_heads * 2))],
[hidden_size, int(intermediate_size * 2)],
[hidden_size, hidden_size],
]
)
# Moe grouped gemm contiguous
grouped_gemm_contiguous_kn_pairs.extend(
[
[moe_intermediate_size, hidden_size],
[hidden_size, int(moe_intermediate_size * 2)],
]
)
# Moe grouped gemm masked
[hidden_size, moe_intermediate_size * 2],
grouped_gemm_masked_kn_pairs.extend(
[
[moe_intermediate_size, hidden_size],
[hidden_size, int(moe_intermediate_size * 2)],
]
)
if has_shared_experts:
logger.debug("Generating kn pairs for models with shared experts.")
gemm_kn_pairs.extend(
[
[hidden_size, int(moe_intermediate_size * 4)],
[int(moe_intermediate_size * 2), hidden_size],
]
)
elif tp_size > 1 and ep_size > 1:
raise ValueError("Not supported to enable EP and TP at the same time for now.")
else:
raise ValueError("Please check the tensor parallel size and expert parallel size.")
return (
gemm_kn_pairs,
@@ -78,7 +129,8 @@ def generate_json(
counter = 0
with open(output_path, "a+", encoding="utf-8") as f:
for block_m in BLOCK_MS:
for block_n in BLOCK_NS:
# NOTES: the block sizes can not be too large, so at least one dim less than 128
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, BLOCK_NS):
if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
NUM_STAGES = [4, 3]
else:
@@ -110,32 +162,42 @@ def generate_json(
def main(args):
with open(os.path.join(args.model, "config.json"), "r") as f:
model_cfg = json.load(f)
logger.debug(
f"TP Size: {args.tensor_parallel_size}, "
f"EP Size: {args.expert_parallel_size}, "
f"has shared experts: {args.has_shared_experts}"
)
logger.info(f"Configurations generated and saved to {args.output}")
(
gemm_kn_pairs,
grouped_gemm_contiguous_kn_pairs,
grouped_gemm_masked_kn_pairs,
) = generate_kn_pairs(model_cfg)
) = generate_kn_pairs(args, model_cfg)
logger.debug(f"GEMM KN pairs: {gemm_kn_pairs}")
logger.debug(f"Grouped GEMM Contiguous KN pairs: {grouped_gemm_contiguous_kn_pairs}")
logger.debug(f"Grouped GEMM Masked KN pairs: {grouped_gemm_masked_kn_pairs}")
if len(gemm_kn_pairs) > 0:
num_gemm = generate_json(
gemm_kn_pairs,
model_cfg["moe_num_experts"],
args.output,
)
logger.info(f"Generated {num_gemm} gemm configuration.")
if len(grouped_gemm_contiguous_kn_pairs) > 0:
num_grouped_contiguous = generate_json(
grouped_gemm_contiguous_kn_pairs,
model_cfg["moe_num_experts"],
args.output,
is_grouped_contiguous=True,
)
logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
if len(grouped_gemm_masked_kn_pairs) > 0:
num_grouped_masked = generate_json(
grouped_gemm_masked_kn_pairs,
model_cfg["moe_num_experts"],
args.output,
is_grouped_masked=True,
)
logger.info(f"Configurations generated and saved to {args.output}")
logger.info(f"Generated {num_gemm} gemm configuration.")
logger.info(f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration.")
logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
@@ -146,6 +208,23 @@ if __name__ == "__main__":
type=str,
required=True,
)
parser.add_argument(
"--tensor-parallel-size",
"--tp",
type=int,
default=1,
)
parser.add_argument(
"--expert-parallel-size",
"--ep",
type=int,
default=1,
)
parser.add_argument(
"--has-shared-experts",
type=str,
default="False",
)
parser.add_argument(
"--output",
type=str,

View File

@@ -162,25 +162,25 @@ def pre_compile_from_config(config_file: str, num_threads: int, expert_parallel:
def main(args):
pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel)
pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel_size)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
"--config-file",
type=str,
default="./deep_gemm_pre_compile_config.jsonl",
)
parser.add_argument(
"--expert_parallel",
"--expert-parallel-size",
"--ep",
type=int,
default=8,
)
parser.add_argument(
"--num_threads",
"--num-threads",
type=int,
default=16,
)

View File

@@ -18,14 +18,20 @@ export DG_CACHE_DIR=$(pwd)/deep_gemm_cache
echo DeepGEMM Cache Dir: $DG_CACHE_DIR
MODEL_PATH=${1:-"/path/to/model"}
EXPERT_PARALLEL=${2:-"8"}
TENSOR_PARALLEL_SIZE=${2:-"1"}
EXPERT_PARALLEL_SIZE=${3:-"8"}
HAS_SHARED_EXPERTS=${4:-"False"}
OUTPUT_FILE=${5:-"./deep_gemm_pre_compile_config.jsonl"}
nproc=$(nproc)
python generate_config.py \
--model $MODEL_PATH \
--output=./deep_gemm_pre_compile_config.jsonl
--tensor-parallel-size $TENSOR_PARALLEL_SIZE \
--expert-parallel-size $EXPERT_PARALLEL_SIZE \
--has-shared-experts $HAS_SHARED_EXPERTS \
--output $OUTPUT_FILE
python pre_compile.py \
--config_file=./deep_gemm_pre_compile_config.jsonl \
--expert_parallel=$EXPERT_PARALLEL \
--num_threads=$nproc
--config-file $OUTPUT_FILE \
--expert-parallel-size $EXPERT_PARALLEL_SIZE \
--num-threads $nproc