mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00

Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* improve per_token_quant_fp8 performance * support moe wfp8apf8 * check glm test * fix noaux_tc op in cudagraph, support noaux_tc return the correct * check * check inf and overwrite score in noaux_tc --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
196 lines
7.5 KiB
Python
196 lines
7.5 KiB
Python
"""
|
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit()
def fused_moe_kernel_paddle(
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    max_possible_num_post_padded,
    num_valid_tokens,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_be: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    stride_asm: tl.constexpr,
    stride_ask: tl.constexpr,
    stride_bse: tl.constexpr,
    stride_bsk: tl.constexpr,
    stride_bsn: tl.constexpr,
    # Block size for block-wise fp8 quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type_enum: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Fused MoE GEMM: multiply every routed token by the weight matrix of the
    expert it was assigned to, writing one output row per (token, top-k slot).

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
      be any shape representing batches and K is the feature dimension of
      each token. Row ``sorted_token_id // top_k`` of A is read.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
      the number of experts, K is the input feature dimension, and N is
      the output feature dimension.
    - C: The output tensor, addressed as ``c_ptr + sorted_token_id * stride_cm``,
      i.e. one row per (token, top-k slot) with N columns. Padding ids are
      never written.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
      repeated topk times and arranged by the expert index they are
      assigned to. Entries ``>= num_valid_tokens`` are padding and are
      masked out of every load/store of A, C, the activation scales and the
      routing weights.
    - expert_ids: A tensor containing the index of the expert for each
      BLOCK_SIZE_M block of ``sorted_token_ids``; it selects which expert
      matrix from B is used for that block.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.

    Quantization modes (constexpr flags):
    - use_int8_w8a16: B is cast to the compute type before each dot. A
      per-expert, per-output-column ``b_scale`` is applied after the K loop.
    - use_fp8_w8a8 with group_k > 0 and group_n > 0 (block-wise): a per-token,
      per-K-group ``a_scale`` and a per-(N-group, K-group) ``b_scale`` are
      applied inside the K loop.
    - use_fp8_w8a8 with per_channel_quant: a per-output-column ``b_scale`` and
      a per-token ``a_scale`` are applied after the K loop.
    - use_fp8_w8a8 otherwise: one ``a_scale`` and one ``b_scale`` per expert
      (both indexed by expert id), applied after the K loop.
    - neither flag: plain accumulation of ``tl.dot(a, b)``.

    Other notes:
    - The grid is sized from ``max_possible_num_post_padded``, while the actual
      padded length is read on device from ``num_tokens_post_padded_ptr``.
      Blocks beyond it return immediately (presumably so the launch grid can
      stay fixed, e.g. under CUDA graph capture).
    - ``even_Ks`` drops the K-boundary masks and is only correct when
      ``K % BLOCK_SIZE_K == 0``.
    - Only ``compute_type_enum == 1`` (bfloat16 output) is supported.
    - Token ids and the expert id are widened to int64 before pointer math so
      ``expert * stride_be`` and ``token * stride_cm`` cannot overflow int32
      for large expert counts or batches.
    """
    # Map the 1-D program id to a (pid_m, pid_n) output tile using the grouped
    # ordering from the Triton matmul tutorial: GROUP_SIZE_M tile rows are
    # walked before moving on to the next band of N columns.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Only bfloat16 output is implemented.
    assert compute_type_enum == 1
    compute_type = tl.bfloat16

    # The grid covers the worst case; the real padded length is on device.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    # Widen to int64: `offs_token * stride_cm` (and the A-row offset) can
    # exceed the int32 range for large batches.
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Ids >= num_valid_tokens are padding slots.
    token_mask = offs_token < num_valid_tokens

    # Columns past N wrap around (% N) so B loads stay in bounds; the store
    # mask at the end discards those columns.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `id // top_k` selects the source row of A for each routed slot.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)

    # Widen to int64: `off_experts * stride_be` overflows int32 once
    # num_experts * N * K reaches 2**31 (e.g. 160 experts of 3072 x 5120).
    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if use_int8_w8a16:
        # One weight scale per output column of this expert.
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # block-wise: scale pointers are prepared here; the per-K-group
            # scales themselves are loaded inside the K loop.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
            b_scale = tl.load(b_scale_ptr + off_experts)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if even_Ks:
            # K is a multiple of BLOCK_SIZE_K: no K-boundary mask needed.
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs, cache_modifier=".cv", eviction_policy="evict_first")
        else:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            if group_k > 0 and group_n > 0:
                # NOTE(review): a single K-group index is used per K step, which
                # assumes each BLOCK_SIZE_K slice lies inside one group_k group
                # (e.g. BLOCK_SIZE_K <= group_k) — confirm against the launcher.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask,
                    mask=token_mask,
                    other=0.0,
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each routed row by its top-k routing weight.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise scales were already applied inside the K loop.
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    # Drop padding rows and the columns that wrapped past N.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

    tl.store(c_ptrs, accumulator, mask=c_mask)
|