mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-30 11:26:39 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code * Add new code files * Add new code files * update code * Try to fix build.sh * Try to fix build.sh * Update code * Update requirements.txt * Update code --------- Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
@@ -16,9 +16,10 @@
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import paddle_use_triton_v2
|
||||
|
||||
|
||||
@triton.jit
|
||||
@paddle_use_triton_v2()
|
||||
def fused_moe_kernel_paddle(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
@@ -31,22 +32,22 @@ def fused_moe_kernel_paddle(
|
||||
num_tokens_post_padded_ptr,
|
||||
|
||||
# Matrix dimensions
|
||||
N,
|
||||
K,
|
||||
num_tokens_post_padded,
|
||||
max_possible_num_post_padded,
|
||||
num_valid_tokens,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
stride_am: tl.constexpr,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_be: tl.constexpr,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_bn: tl.constexpr,
|
||||
stride_cm: tl.constexpr,
|
||||
stride_cn: tl.constexpr,
|
||||
stride_asm: tl.constexpr,
|
||||
stride_ask: tl.constexpr,
|
||||
stride_bse: tl.constexpr,
|
||||
stride_bsk: tl.constexpr,
|
||||
stride_bsn: tl.constexpr,
|
||||
# Block size for block-wise fp8 quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
@@ -87,7 +88,7 @@ def fused_moe_kernel_paddle(
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
|
||||
Reference in New Issue
Block a user