polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -28,8 +28,13 @@ def group_gemm(
scale: paddle.Tensor,
output: paddle.Tensor,
):
assert (input.dim() == 2 and tokens_expert_prefix_sum.dim() == 1
and weight.dim() == 3 and scale.dim() == 2 and output.dim() == 2)
assert (
input.dim() == 2
and tokens_expert_prefix_sum.dim() == 1
and weight.dim() == 3
and scale.dim() == 2
and output.dim() == 2
)
num_tokens = input.shape[0]
dim_in = input.shape[1]
dim_out = weight.shape[1]
@@ -66,7 +71,8 @@ def group_gemm(
weight_i,
weight_scale=scale_i,
weight_dtype="int8",
group_size=-1)
group_size=-1,
)
def iluvatar_moe_expert_ffn(
@@ -90,13 +96,24 @@ def iluvatar_moe_expert_ffn(
assert quant_method in ("weight_only_int8")
assert not used_in_ep_low_latency
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
up_gate_proj_output = paddle.empty([permute_input.shape[0], up_gate_proj_weight.shape[1]],
dtype=permute_input.dtype)
group_gemm(permute_input, tokens_expert_prefix_sum_cpu, up_gate_proj_weight,
up_gate_proj_scale, up_gate_proj_output)
up_gate_proj_output = paddle.empty(
[permute_input.shape[0], up_gate_proj_weight.shape[1]],
dtype=permute_input.dtype,
)
group_gemm(
permute_input,
tokens_expert_prefix_sum_cpu,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_output,
)
act_out = swiglu(up_gate_proj_output)
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]],
dtype=act_out.dtype)
group_gemm(act_out, tokens_expert_prefix_sum_cpu, down_proj_weight, down_proj_scale,
output)
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]], dtype=act_out.dtype)
group_gemm(
act_out,
tokens_expert_prefix_sum_cpu,
down_proj_weight,
down_proj_scale,
output,
)
return output