mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
Add DeepGEMM pre-compile tools (#2819)
This tool allows you to compile all possible kernels in advance through the model's config.json, and avoids the situation where uncompiled kernel is encountered and JIT is executed when certain requests arrive.
This commit is contained in:
151
tools/deep_gemm_pre-compile/generate_config.py
Normal file
151
tools/deep_gemm_pre-compile/generate_config.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
|
||||
get_smem_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console_handler = logging.StreamHandler()
|
||||
logger.addHandler(console_handler)
|
||||
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))
|
||||
|
||||
|
||||
def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
|
||||
hidden_size = model_cfg["hidden_size"]
|
||||
intermediate_size = model_cfg["intermediate_size"]
|
||||
moe_intermediate_size = model_cfg["moe_intermediate_size"]
|
||||
num_attention_heads = model_cfg["num_attention_heads"]
|
||||
num_key_value_heads = model_cfg["num_key_value_heads"]
|
||||
head_dim = int(hidden_size / num_attention_heads)
|
||||
gemm_kn_pairs = [
|
||||
# Dense normal gemm
|
||||
[hidden_size, intermediate_size * 2],
|
||||
[intermediate_size, hidden_size],
|
||||
[hidden_size, hidden_size],
|
||||
[hidden_size, (num_attention_heads + num_key_value_heads * 2) * head_dim],
|
||||
]
|
||||
grouped_gemm_contiguous_kn_pairs = [
|
||||
# Moe grouped gemm contiguous
|
||||
[hidden_size, moe_intermediate_size * 2],
|
||||
[moe_intermediate_size, hidden_size],
|
||||
]
|
||||
grouped_gemm_masked_kn_pairs = [
|
||||
# Moe grouped gemm masked
|
||||
[hidden_size, moe_intermediate_size * 2],
|
||||
[moe_intermediate_size, hidden_size],
|
||||
]
|
||||
|
||||
return gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs
|
||||
|
||||
|
||||
def generate_json(
|
||||
kn_pairs: list,
|
||||
moe_num_experts: int,
|
||||
output_path: str,
|
||||
is_grouped_contiguous: bool = False,
|
||||
is_grouped_masked: bool = False,
|
||||
):
|
||||
if not is_grouped_contiguous:
|
||||
BLOCK_MS = [64, 128, 256]
|
||||
else:
|
||||
BLOCK_MS = [128]
|
||||
BLOCK_NS = list(range(16, 129, 8)) + [144, 160]
|
||||
TMA_MULTICAST_CONFIGS = [(1, True), (1, False), (2, True), (2, False)]
|
||||
counter = 0
|
||||
with open(output_path, "a+", encoding="utf-8") as f:
|
||||
for block_m in BLOCK_MS:
|
||||
for block_n in BLOCK_NS:
|
||||
if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
|
||||
NUM_STAGES = [4, 3]
|
||||
else:
|
||||
NUM_STAGES = [8, 7, 6, 5, 4, 3]
|
||||
for num_stages in NUM_STAGES:
|
||||
for kn_pair in kn_pairs:
|
||||
smem_config = get_smem_config(
|
||||
num_stages, kn_pair[0], block_m, block_n
|
||||
)
|
||||
for tma_multicast_config in TMA_MULTICAST_CONFIGS:
|
||||
cfg = {
|
||||
"N": kn_pair[1],
|
||||
"K": kn_pair[0],
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"IS_GROUPED_CONTIGUOUS": is_grouped_contiguous,
|
||||
"IS_GROUPED_MASKED": is_grouped_masked,
|
||||
"MOE_NUM_EXPERTS": moe_num_experts,
|
||||
}
|
||||
f.write(json.dumps(cfg) + "\n")
|
||||
counter += 1
|
||||
|
||||
return counter
|
||||
|
||||
|
||||
def main(args):
|
||||
with open(os.path.join(args.model, "config.json"), "r") as f:
|
||||
model_cfg = json.load(f)
|
||||
|
||||
gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs = (
|
||||
generate_kn_pairs(model_cfg)
|
||||
)
|
||||
num_gemm = generate_json(
|
||||
gemm_kn_pairs,
|
||||
model_cfg["moe_num_experts"],
|
||||
args.output,
|
||||
)
|
||||
num_grouped_contiguous = generate_json(
|
||||
grouped_gemm_contiguous_kn_pairs,
|
||||
model_cfg["moe_num_experts"],
|
||||
args.output,
|
||||
is_grouped_contiguous=True,
|
||||
)
|
||||
num_grouped_masked = generate_json(
|
||||
grouped_gemm_masked_kn_pairs,
|
||||
model_cfg["moe_num_experts"],
|
||||
args.output,
|
||||
is_grouped_masked=True,
|
||||
)
|
||||
logger.info(f"Configurations generated and saved to {args.output}")
|
||||
logger.info(f"Generated {num_gemm} gemm configuration.")
|
||||
logger.info(
|
||||
f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configuration."
|
||||
)
|
||||
logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configuration.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="./deep_gemm_pre_compile_config.jsonl",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
184
tools/deep_gemm_pre-compile/pre_compile.py
Normal file
184
tools/deep_gemm_pre-compile/pre_compile.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from queue import Queue
|
||||
from time import time
|
||||
|
||||
import paddle
|
||||
from tqdm import tqdm
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.compiler import build
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.template import (
|
||||
cpp_format, generate)
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
|
||||
includes as gemm_includes
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
|
||||
template as gemm_template
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
|
||||
includes as m_grouped_includes
|
||||
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
|
||||
template as m_grouped_template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
console_handler = logging.StreamHandler()
|
||||
logger.addHandler(console_handler)
|
||||
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))
|
||||
|
||||
|
||||
class CompileWorker(threading.Thread):
|
||||
def __init__(self, queue, pbar):
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
self.pbar = pbar
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
cfg = self.queue.get()
|
||||
if cfg is None:
|
||||
break
|
||||
|
||||
try:
|
||||
logger.debug(f"Compiling for config: {cfg}")
|
||||
keys = {
|
||||
"N": cfg["N"],
|
||||
"K": cfg["K"],
|
||||
"BLOCK_M": cfg["BLOCK_M"],
|
||||
"BLOCK_N": cfg["BLOCK_N"],
|
||||
"SWIZZLE_D_MODE": cfg["SWIZZLE_D_MODE"],
|
||||
"BLOCK_N_PADDING": cfg["BLOCK_N_PADDING"],
|
||||
"NUM_STAGES": cfg["NUM_STAGES"],
|
||||
"NUM_TMA_MULTICAST": cfg["NUM_TMA_MULTICAST"],
|
||||
"IS_TMA_MULTICAST_ON_A": cfg["IS_TMA_MULTICAST_ON_A"],
|
||||
}
|
||||
arg_defs = (
|
||||
("lhs", paddle.float8_e4m3fn),
|
||||
("lhs_scales", paddle.float32),
|
||||
("rhs", paddle.float8_e4m3fn),
|
||||
("rhs_scales", paddle.float32),
|
||||
("out", paddle.bfloat16),
|
||||
("m", int),
|
||||
("stream", paddle.device.cuda.Stream),
|
||||
("num_sms", int),
|
||||
("smem_size", int),
|
||||
)
|
||||
name = "gemm_fp8_fp8_bf16_nt"
|
||||
includes = gemm_includes
|
||||
template = gemm_template
|
||||
if cfg["IS_GROUPED_CONTIGUOUS"]:
|
||||
keys["GEMM_TYPE"] = "GroupedContiguous"
|
||||
arg_defs = (
|
||||
("lhs", paddle.float8_e4m3fn),
|
||||
("lhs_scales", paddle.float32),
|
||||
("rhs", paddle.float8_e4m3fn),
|
||||
("rhs_scales", paddle.float32),
|
||||
("out", paddle.bfloat16),
|
||||
("grouped_layout", paddle.int32),
|
||||
("m", int),
|
||||
("num_groups", int),
|
||||
("stream", paddle.device.cuda.Stream),
|
||||
("num_sms", int),
|
||||
("smem_size", int),
|
||||
)
|
||||
if cfg["IS_GROUPED_MASKED"]:
|
||||
keys["GEMM_TYPE"] = "GroupedMasked"
|
||||
arg_defs = (
|
||||
("lhs", paddle.float8_e4m3fn),
|
||||
("lhs_scales", paddle.float32),
|
||||
("rhs", paddle.float8_e4m3fn),
|
||||
("rhs_scales", paddle.float32),
|
||||
("out", paddle.bfloat16),
|
||||
("grouped_layout", paddle.int32),
|
||||
("m", int),
|
||||
("stream", paddle.device.cuda.Stream),
|
||||
("num_sms", int),
|
||||
("smem_size", int),
|
||||
)
|
||||
if cfg["IS_GROUPED_CONTIGUOUS"] or cfg["IS_GROUPED_MASKED"]:
|
||||
keys["NUM_GROUPS"] = int(
|
||||
cfg["MOE_NUM_EXPERTS"] / cfg["EXPERT_PARALLEL"]
|
||||
)
|
||||
includes = m_grouped_includes
|
||||
template = m_grouped_template
|
||||
name = "m_grouped_gemm_fp8_fp8_bf16_nt"
|
||||
|
||||
code = generate(includes, arg_defs, cpp_format(template, keys))
|
||||
build(name, arg_defs, code)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compile config {cfg}: {str(e)}")
|
||||
raise RuntimeError(e)
|
||||
finally:
|
||||
self.pbar.update(1)
|
||||
self.queue.task_done()
|
||||
|
||||
|
||||
def pre_compile_from_config(config_file: str, num_threads: int, expert_parallel: int):
|
||||
with open(config_file, "r") as f:
|
||||
start_time = time()
|
||||
lines = f.readlines()
|
||||
|
||||
queue = Queue()
|
||||
pbar = tqdm(total=len(lines), desc="Compiling")
|
||||
workers = []
|
||||
for _ in range(num_threads):
|
||||
worker = CompileWorker(queue, pbar)
|
||||
worker.start()
|
||||
workers.append(worker)
|
||||
|
||||
for line in lines:
|
||||
cfg = json.loads(line)
|
||||
cfg["EXPERT_PARALLEL"] = expert_parallel
|
||||
queue.put(cfg)
|
||||
|
||||
queue.join()
|
||||
|
||||
for _ in range(num_threads):
|
||||
queue.put(None)
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
|
||||
pbar.close()
|
||||
|
||||
logger.info(f"Total compliation time: {time() - start_time:.2f} seconds")
|
||||
|
||||
|
||||
def main(args):
|
||||
pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
type=str,
|
||||
default="./deep_gemm_pre_compile_config.jsonl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expert_parallel",
|
||||
"--ep",
|
||||
type=int,
|
||||
default=8,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_threads",
|
||||
type=int,
|
||||
default=16,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
31
tools/deep_gemm_pre-compile/pre_compile.sh
Normal file
31
tools/deep_gemm_pre-compile/pre_compile.sh
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export PRE_COMPILE_LOG_LEVEL="INFO"
|
||||
export DG_CACHE_DIR=$(pwd)/deep_gemm_cache
|
||||
|
||||
echo DeepGEMM Cache Dir: $DG_CACHE_DIR
|
||||
|
||||
MODEL_PATH=${1:-"/path/to/model"}
|
||||
EXPERT_PARALLEL=${2:-"8"}
|
||||
nproc=$(nproc)
|
||||
|
||||
python generate_config.py \
|
||||
--model $MODEL_PATH \
|
||||
--output=./deep_gemm_pre_compile_config.jsonl
|
||||
|
||||
python pre_compile.py \
|
||||
--config_file=./deep_gemm_pre_compile_config.jsonl \
|
||||
--expert_parallel=$EXPERT_PARALLEL \
|
||||
--num_threads=$nproc
|
Reference in New Issue
Block a user