Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Benchmark] Add GEMM & MoE kernel bench (#4809)
@@ -17,9 +17,7 @@ import os
 import shutil
 import unittest
 
-import numpy as np
 import paddle
-import paddle.device.cuda.graphs as graphs
 
 from fastdeploy.config import (
     CacheConfig,
@@ -30,15 +28,25 @@ from fastdeploy.config import (
     ParallelConfig,
 )
 from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
+from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import (
+    BlockWiseFP8Config,
+)
 from fastdeploy.model_executor.layers.quantization.weight_only import (
     WINT4Config,
     WINT8Config,
 )
 from fastdeploy.scheduler import SchedulerConfig
+from tests.utils import OpPerformanceTester
 
 paddle.set_default_dtype("bfloat16")
 paddle.seed(1024)
 
+QUANT_CONFIG_MAP = {
+    "wint8": WINT8Config({}),
+    "wint4": WINT4Config({}),
+    "block_wise_fp8": BlockWiseFP8Config(weight_block_size=[128, 128]),
+}
+
 
 class QuantizedLinearWrapper(paddle.nn.Layer):
     def __init__(
@@ -56,7 +64,7 @@ class QuantizedLinearWrapper(paddle.nn.Layer):
         self.fd_config = FDConfig(
             model_config=self.model_config,
             parallel_config=ParallelConfig({"tensor_parallel_size": self.tp_size}),
-            quant_config=WINT8Config({}) if quant_type == "wint8" else WINT4Config({}),
+            quant_config=QUANT_CONFIG_MAP[quant_type],
             load_config=LoadConfig({}),
             graph_opt_config=GraphOptimizationConfig({}),
             scheduler_config=SchedulerConfig({}),
@@ -158,66 +166,20 @@ class TestQuantizedLinear(unittest.TestCase):
         )
         mm = quantized_linear
 
-        print(f"========Method: {type}, Quant Type: {quant_type}=========")
-        print(
-            "{:<15} {:<40} {:<15} {:<15} {:<15}".format(
-                "Batch Size", "Last 5 Times (us)", "Last Time (us)", "TFlops", "TB/s"
-            )
-        )
-
-        num_layers = self.model_config.num_hidden_layers
-        real_weight_layers = self.model_config.num_hidden_layers
-        linear = [None] * real_weight_layers
-        for i in range(real_weight_layers):
-            linear[i] = mm
-
-        linear_cuda_graphs = [None] * 2000
-        input = [None] * 2000
-        # for idx, bsz in enumerate([1024 * i for i in [1,2,4,8,16,32,64]]):
-        for idx, bsz in enumerate([1, 8, 16, 32, 128, 1024]):
-
-            input[idx] = paddle.rand((bsz, input_size), dtype=paddle.bfloat16)
-
-            def fake_model_run():
-                for j in range(num_layers):
-                    out = linear[j % real_weight_layers](input[idx])
-
-                return out
-
-            fake_model_run()
-
-            linear_cuda_graphs[idx] = graphs.CUDAGraph()
-            linear_cuda_graphs[idx].capture_begin()
-
-            fake_model_run()
-
-            linear_cuda_graphs[idx].capture_end()
-
-            num_tests = 20
-            start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
-            end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
-            for i in range(num_tests):
-                start_events[i].record()
-
-                linear_cuda_graphs[idx].replay()
-
-                end_events[i].record()
-            paddle.device.synchronize()
-
-            times = np.array([round(s.elapsed_time(e), 2) for s, e in zip(start_events, end_events)])[1:]
-            times = times * 1e3 / num_layers
-            times = np.array([round(time, 2) for time in times])
-            last_5_times = times[-5:]
-            last_time = times[-1]  # us
-
-            flops = 2 * bsz * weight_size
-            memory = weight_size
-            tfloaps = round(flops / (1e12) / (last_time * 1e-6), 1)
-            tbps = round(memory / (1e12) / (last_time * 1e-6), 1)
-            print("{:<15} {:<40} {:<15} {:<15} {:<15}".format(bsz, str(last_5_times), last_time, tfloaps, tbps))
+        tester = OpPerformanceTester(
+            op_name=f"{type}-{quant_type}",
+            op_fn=mm,
+            num_layers=self.model_config.num_hidden_layers,
+            weight_size=weight_size,
+        )
+
+        tester.benchmark(
+            input_size=input_size,
+            batch_sizes=[1, 8, 16, 32, 128],
+        )
 
     def test_quantized_linear(self):
-        for type in ["qkv_proj", "o_proj", "out_proj+qkv_proj"]:
+        for type in ["qkv_proj", "o_proj"]:
            for quant_type in ["wint4", "wint8"]:
                for use_machete in ["0", "1"]:
                    os.environ["FD_USE_MACHETE"] = use_machete
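For orientation, the refactored test now drives the shared helper roughly like this (a minimal sketch, assuming a plain paddle.nn.Linear as a stand-in for the quantized layer; the real test passes the QuantizedLinearWrapper's QKVParallelLinear/RowParallelLinear forward as op_fn, and weight_size comes from the actual layer shapes):

    import paddle

    from tests.utils import OpPerformanceTester

    paddle.set_default_dtype("bfloat16")  # benchmark inputs are bfloat16

    hidden_size = 8192  # hypothetical size, for illustration only
    linear = paddle.nn.Linear(hidden_size, hidden_size)  # stand-in for the quantized linear

    tester = OpPerformanceTester(
        op_name="demo-linear",
        op_fn=linear,  # called as op_fn(x) once per simulated layer
        num_layers=4,  # the op is replayed num_layers times inside one CUDA graph
        weight_size=hidden_size * hidden_size,  # feeds the TFlops/TB/s columns
    )
    tester.benchmark(input_size=hidden_size, batch_sizes=[1, 8, 32])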
248 tests/layers/test_w4a8_moe.py (Normal file)
@@ -0,0 +1,248 @@
+import json
+import os
+import shutil
+import unittest
+
+import paddle
+from paddle.distributed import fleet
+
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    LoadConfig,
+    ModelConfig,
+    ParallelConfig,
+)
+from fastdeploy.model_executor.layers.moe.moe import FusedMoE
+from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
+from fastdeploy.scheduler import SchedulerConfig
+from fastdeploy.worker.worker_process import init_distributed_environment
+from tests.utils import OpPerformanceTester
+
+paddle.set_default_dtype("bfloat16")
+
+
+class FuseMoEWrapper(paddle.nn.Layer):
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        tp_size: int = 1,
+        tp_rank: int = 0,
+        ep_size: int = 1,
+        ep_rank: int = 0,
+        prefix: str = "layer0",
+        nnodes: int = 1,
+    ):
+        super().__init__()
+        self.model_config = model_config
+
+        self.tp_size = tp_size
+        self.ep_size = ep_size
+        self.ep_rank = ep_rank
+
+        self.prefix = prefix
+        self.fd_config = FDConfig(
+            model_config=self.model_config,
+            parallel_config=ParallelConfig(
+                {
+                    "tensor_parallel_size": self.tp_size,
+                    "expert_parallel_size": self.ep_size,
+                    "expert_parallel_rank": self.ep_rank,
+                    "data_parallel_size": self.ep_size,
+                }
+            ),
+            quant_config=W4A8Config(is_permuted=False, hadamard_block_size=128),
+            # quant_config=W4AFP8Config(weight_scale_dict=None, act_scale_dict=None, is_permuted=False, hadamard_block_size=128),
+            scheduler_config=SchedulerConfig({}),
+            cache_config=CacheConfig({}),
+            graph_opt_config=GraphOptimizationConfig({}),
+            load_config=LoadConfig({}),
+            ips=",".join(["0"] * nnodes),
+        )
+        self.fd_config.parallel_config.tp_group = None
+        self.fd_config.parallel_config.tensor_parallel_rank = tp_rank
+        self.fd_config.parallel_config.expert_parallel_size = self.ep_size
+        if self.ep_size > 1:
+            self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+        self.fd_config.scheduler_config.splitwise_role = "mixed"
+        self.fd_config.model_config.moe_phase.phase = "decode"
+
+        weight_key_map = {
+            "gate_weight_key": f"{self.prefix}.gate.weight",
+            "gate_correction_bias_key": f"{self.prefix}.moe_statics.e_score_correction_bias",
+            "up_gate_proj_expert_weight_key": f"{self.prefix}.experts.{{}}.up_gate_proj.weight",
+            "down_proj_expert_weight_key": f"{self.prefix}.experts.{{}}.down_proj.weight",
+            "up_gate_proj_expert_weight_scale_key": f"{self.prefix}.experts.{{}}.up_gate_proj.weight_scale",
+            "down_proj_expert_weight_scale_key": f"{self.prefix}.experts.{{}}.down_proj.weight_scale",
+            "up_gate_proj_expert_in_scale_key": f"{self.prefix}.experts.{{}}.up_gate_proj.activation_scale",
+            "down_proj_expert_in_scale_key": f"{self.prefix}.experts.{{}}.down_proj.activation_scale",
+        }
+
+        self.fused_moe = FusedMoE(
+            fd_config=self.fd_config,
+            moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
+            num_experts=self.fd_config.model_config.moe_num_experts,
+            top_k=self.fd_config.model_config.moe_k,
+            # avoid invoking clean_low_latency_buffer in mixed EP
+            layer_idx=666,
+            weight_key_map=weight_key_map,
+            topk_method="noaux_tc",
+            topk_group=4,
+            n_group=8,
+            gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
+            # gate_correction_bias=gate_correction_bias_real_data
+        )
+        self.pack_num = 2
+        moe_layer = self.fused_moe
+
+        up_gate_proj_weight_shape = [
+            moe_layer.num_local_experts,
+            moe_layer.hidden_size // self.pack_num,
+            moe_layer.moe_intermediate_size * 2,
+        ]
+        down_proj_weight_shape = [
+            moe_layer.num_local_experts,
+            moe_layer.moe_intermediate_size // self.pack_num,
+            moe_layer.hidden_size,
+        ]
+        up_gate_proj_weight_scale_shape = [
+            moe_layer.num_local_experts,
+            moe_layer.moe_intermediate_size * 2,
+        ]
+        down_proj_weight_scale_shape = [
+            moe_layer.num_local_experts,
+            moe_layer.hidden_size,
+        ]
+
+        up_gate_proj_weight = (paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16) * 100).cast(paddle.int8)
+        down_proj_weight = (paddle.randn(down_proj_weight_shape, paddle.bfloat16) * 100).cast(paddle.int8)
+
+        up_gate_proj_weight_scale = paddle.randn(up_gate_proj_weight_scale_shape, paddle.bfloat16)
+        down_proj_weight_scale = paddle.randn(down_proj_weight_scale_shape, paddle.bfloat16)
+
+        up_gate_proj_in_scale = paddle.randn([self.fd_config.model_config.moe_num_experts, 1], paddle.float32)
+        down_proj_in_scale = paddle.randn([self.fd_config.model_config.moe_num_experts, 1], paddle.float32)
+
+        local_expert_ids = list(
+            range(moe_layer.expert_id_offset, moe_layer.expert_id_offset + moe_layer.num_local_experts)
+        )
+        state_dict = {}
+        up_gate_proj_expert_weight_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_key")
+        up_gate_proj_expert_weight_scale_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key")
+        up_gate_proj_expert_in_scale_key = moe_layer.weight_key_map.get("up_gate_proj_expert_in_scale_key")
+        down_proj_expert_weight_key = moe_layer.weight_key_map.get("down_proj_expert_weight_key")
+        down_proj_expert_weight_scale_key = moe_layer.weight_key_map.get("down_proj_expert_weight_scale_key")
+        down_proj_expert_in_scale_key = moe_layer.weight_key_map.get("down_proj_expert_in_scale_key")
+
+        for expert_idx in local_expert_ids:
+            up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
+            up_gate_proj_expert_weight_scale_key_name = up_gate_proj_expert_weight_scale_key.format(expert_idx)
+            down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
+            down_proj_expert_weight_scale_key_name = down_proj_expert_weight_scale_key.format(expert_idx)
+
+            state_dict[up_gate_proj_expert_weight_key_name] = up_gate_proj_weight[
+                expert_idx - moe_layer.expert_id_offset
+            ]
+            state_dict[up_gate_proj_expert_weight_scale_key_name] = up_gate_proj_weight_scale[
+                expert_idx - moe_layer.expert_id_offset
+            ]
+            state_dict[down_proj_expert_weight_key_name] = down_proj_weight[expert_idx - moe_layer.expert_id_offset]
+            state_dict[down_proj_expert_weight_scale_key_name] = down_proj_weight_scale[
+                expert_idx - moe_layer.expert_id_offset
+            ]
+
+        for expert_idx in range(self.fd_config.model_config.moe_num_experts):
+            up_gate_proj_expert_in_scale_key_name = up_gate_proj_expert_in_scale_key.format(expert_idx)
+            down_proj_expert_in_scale_key_name = down_proj_expert_in_scale_key.format(expert_idx)
+            state_dict[up_gate_proj_expert_in_scale_key_name] = up_gate_proj_in_scale[expert_idx]
+            state_dict[down_proj_expert_in_scale_key_name] = down_proj_in_scale[expert_idx]
+
+        moe_layer.load_state_dict(state_dict)
+
+
+class TestW4A8FusedMoE(unittest.TestCase):
+    def setUp(self) -> None:
+        self.architectures = ["Ernie4_5_MoeForCausalLM"]
+        self.hidden_size = 8192
+        self.moe_intermediate_size = 3584
+        self.moe_num_experts = 64
+        self.moe_k = 8
+        self.hidden_act = "silu"
+        self.num_attention_heads = 64
+        self.num_hidden_layers = 54
+        self.model_config = self.build_model_config()
+
+    def build_model_config(self) -> ModelConfig:
+        model_name_or_path = self.build_config_json()
+        return ModelConfig(
+            {
+                "model": model_name_or_path,
+                "max_model_len": 2048,
+            }
+        )
+
+    def build_config_json(self) -> str:
+        config_dict = {
+            "architectures": self.architectures,
+            "hidden_size": self.hidden_size,
+            "moe_intermediate_size": self.moe_intermediate_size,
+            "moe_num_experts": self.moe_num_experts,
+            "moe_k": self.moe_k,
+            "hidden_act": self.hidden_act,
+            "num_attention_heads": self.num_attention_heads,
+            "num_hidden_layers": self.num_hidden_layers,
+            "dtype": "bfloat16",
+        }
+
+        tmp_dir = f"./tmp_w4a8_moe_{paddle.distributed.get_rank()}"
+        os.makedirs(tmp_dir, exist_ok=True)
+        with open(f"./{tmp_dir}/config.json", "w") as f:
+            json.dump(config_dict, f)
+        self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
+        return self.model_name_or_path
+
+    def test_fused_moe(self):
+        init_distributed_environment()
+
+        gating = paddle.nn.Linear(self.model_config.hidden_size, self.model_config.moe_num_experts)
+        gating.to(dtype=paddle.float32)  # its dtype is bfloat16 by default, but the forward input is float32
+        gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32))
+
+        # ep_size = paddle.distributed.get_world_size()
+        # ep_rank = paddle.distributed.get_rank()
+        ep_size = 1
+        ep_rank = 0
+
+        tp_size = 1
+        tp_rank = 0
+
+        nnodes = (ep_size + 7) // 8
+
+        # This line must be kept, otherwise it affects the uniformity of expert routing!
+        paddle.seed(ep_rank + 100)
+
+        fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes).fused_moe
+        weight_size = fused_moe.top_k * fused_moe.hidden_size * fused_moe.moe_intermediate_size * 3 / 2
+
+        tester = OpPerformanceTester(
+            op_name="w4a8-moe",
+            op_fn=fused_moe,
+            num_layers=self.model_config.num_hidden_layers,
+            weight_size=weight_size,
+            gate=gating,
+        )
+
+        tester.benchmark(
+            input_size=self.model_config.hidden_size,
+            batch_sizes=[10, 20, 40, 60, 80, 100, 128],
+        )
+
+    def tearDown(self) -> None:
+        if self.model_name_or_path:
+            print("Remove tmp model config file")
+            shutil.rmtree(self.model_name_or_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
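The weight_size fed to the tester above encodes the per-token weight traffic of this MoE: each token is routed to moe_k experts, and each expert applies up_gate_proj (hidden_size x 2*moe_intermediate_size) plus down_proj (moe_intermediate_size x hidden_size), i.e. 3*hidden_size*moe_intermediate_size parameters, halved because two 4-bit weights pack into one byte. A sketch of the arithmetic with this test's numbers:

    top_k = 8  # moe_k
    hidden_size = 8192
    moe_intermediate_size = 3584

    params_per_expert = 3 * hidden_size * moe_intermediate_size  # up, gate and down projections
    params_touched = top_k * params_per_expert                   # weights read per token
    weight_size = params_touched // 2                            # int4: two weights per byte

    print(params_touched)  # 704643072 parameters
    print(weight_size)     # 352321536 bytes, ~0.35 GB of weights per token

Since the tester feeds this single number into both metrics, the TB/s column reflects bytes moved, while the TFlops column (2 * bsz * weight_size) counts per packed byte, a factor of two below the conventional 2*MACs figure.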
@@ -16,6 +16,10 @@
 
 from unittest.mock import Mock
 
+import numpy as np
+import paddle
+import paddle.device.cuda.graphs as graphs
+
 from fastdeploy.config import (
     CacheConfig,
     FDConfig,
@@ -68,3 +72,66 @@ def get_default_test_fd_config():
         test_mode=True,
     )
     return fd_config
+
+
+class OpPerformanceTester:
+    def __init__(self, op_name, op_fn, num_layers=20, weight_size=None, gate=None):
+        self.op_name = op_name
+        self.op_fn = op_fn
+        self.num_layers = num_layers
+        self.weight_size = weight_size
+        self.gate = gate
+
+    def _fake_model_run(self, x):
+        for j in range(self.num_layers):
+            if self.gate:
+                out = self.op_fn(x, self.gate)
+            else:
+                out = self.op_fn(x)
+        return out
+
+    def benchmark(self, input_size, batch_sizes, dtype="bfloat16", num_warmup=1, num_tests=10):
+        print(f"======== {self.op_name} Performance ========")
+        print(
+            "{:<15} {:<40} {:<15} {:<15} {:<15}".format(
+                "Batch Size", "Last 5 Times (us)", "Last Time (us)", "TFlops", "TB/s"
+            )
+        )
+
+        for idx, bsz in enumerate(batch_sizes):
+            x = paddle.rand((bsz, input_size), dtype=dtype)
+
+            self._fake_model_run(x)
+
+            graph = graphs.CUDAGraph()
+            graph.capture_begin()
+            self._fake_model_run(x)
+            graph.capture_end()
+
+            start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
+            end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
+
+            for i in range(num_tests):
+                start_events[i].record()
+                graph.replay()
+                end_events[i].record()
+
+            paddle.device.synchronize()
+
+            times = np.array([round(s.elapsed_time(e), 2) for s, e in zip(start_events, end_events)])[num_warmup:]
+            times = times * 1e3 / self.num_layers  # us / layer
+            times = np.array([round(time, 2) for time in times])
+            last_5_times = times[-5:]
+            last_time = times[-1]
+
+            tflops = None
+            tbps = None
+            if self.weight_size:
+                flops = 2 * bsz * self.weight_size
+                memory = self.weight_size
+                tflops = round(flops / 1e12 / (last_time * 1e-6), 1)
+                tbps = round(memory / 1e12 / (last_time * 1e-6), 1)
+
+                print("{:<15} {:<40} {:<15} {:<15} {:<15}".format(bsz, str(last_5_times), last_time, tflops, tbps))
+            else:
+                print("{:<15} {:<40} {:<15}".format(bsz, str(last_5_times), last_time))
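The helper's measurement loop follows the usual CUDA-graph timing recipe: one eager run first (so workspaces get allocated before capture), one capture of the whole num_layers sequence, then timed replays whose first num_warmup samples are dropped. The same pattern, stripped to its core (a minimal sketch, assuming any CUDA-resident callable fn):

    import numpy as np
    import paddle
    import paddle.device.cuda.graphs as graphs


    def time_with_cuda_graph(fn, num_tests=10, num_warmup=1):
        fn()  # eager warmup before capture

        graph = graphs.CUDAGraph()
        graph.capture_begin()
        fn()  # the kernel launches recorded into the graph
        graph.capture_end()

        starts = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
        ends = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
        for i in range(num_tests):
            starts[i].record()
            graph.replay()  # re-launches the captured kernels with minimal CPU overhead
            ends[i].record()
        paddle.device.synchronize()

        ms = np.array([s.elapsed_time(e) for s, e in zip(starts, ends)])[num_warmup:]
        return ms * 1e3  # microseconds per replay

Replaying a captured graph removes per-kernel launch latency from the measurement, which matters at small batch sizes where launch overhead would otherwise dominate the GEMM or MoE kernel itself.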