mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-11-01 04:12:58 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			191 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			191 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import triton.language as tl
 | |
| 
 | |
| from fastdeploy.model_executor.ops.triton_ops.triton_utils_v2 import (
 | |
|     paddle_use_triton_v2,
 | |
| )
 | |
| 
 | |
| 
 | |
@paddle_use_triton_v2()
def fused_moe_kernel_paddle(
    # Pointers to matrices
    a_ptr,  # activations A, shape (*, K) flattened by token
    b_ptr,  # stacked expert weights B, shape (E, N, K) per the docstring below
    c_ptr,  # output C, written at (token, n) offsets
    a_scale_ptr,  # activation scales (fp8 path only)
    b_scale_ptr,  # weight scales (fp8 / int8-w8a16 paths)
    topk_weights_ptr,  # per-(token, expert) routing weights
    sorted_token_ids_ptr,  # token ids sorted by assigned expert, padded per block
    expert_ids_ptr,  # expert index for each M-block
    num_tokens_post_padded_ptr,  # scalar: actual token count after padding
    # Matrix dimensions
    max_possible_num_post_padded,
    num_valid_tokens,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_be: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    stride_asm: tl.constexpr,
    stride_ask: tl.constexpr,
    stride_bse: tl.constexpr,
    stride_bsk: tl.constexpr,
    stride_bsn: tl.constexpr,
    # Block size for block-wise fp8 quantization
    group_n: tl.constexpr,  # quantization group size along N; 0 disables block-wise scales
    group_k: tl.constexpr,  # quantization group size along K; 0 disables block-wise scales
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,  # if true, scale output rows by topk routing weight
    top_k: tl.constexpr,  # number of experts each token is routed to
    compute_type_enum: tl.constexpr,  # output/compute dtype selector; only 1 (bfloat16) is supported
    use_fp8_w8a8: tl.constexpr,  # fp8 weights + fp8 activations dequant path
    use_int8_w8a16: tl.constexpr,  # int8 weights + 16-bit activations dequant path
    even_Ks: tl.constexpr,  # true when K % BLOCK_SIZE_K == 0, enabling unmasked K-loads
):
    """
    Fused Mixture-of-Experts GEMM kernel: each program computes one
    BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ B[expert] for the expert
    assigned to its M-block, with optional fp8/int8 dequantization and
    routing-weight scaling.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # Map the 1-D program id onto a (pid_m, pid_n) tile using grouped ordering:
    # GROUP_SIZE_M consecutive M-tiles share a column sweep (the standard
    # Triton grouped-matmul launch pattern, presumably for L2 reuse of B).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may be smaller than GROUP_SIZE_M.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Only bfloat16 output is supported by this kernel.
    assert compute_type_enum == 1
    compute_type = tl.bfloat16

    # The grid is sized for the worst case (max_possible_num_post_padded);
    # blocks beyond the actual padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding slots carry ids >= num_valid_tokens; mask them out of all
    # loads/stores so they never touch real memory rows.
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_token // top_k maps the topk-repeated token id back to its source
    # row in A (each token appears top_k times in sorted_token_ids).
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)

    # All rows in this M-block belong to the same expert by construction.
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if use_int8_w8a16:
        # Per-output-channel weight scale; loaded once, applied after the K-loop.
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise scales: A is scaled per (token, k-group), B per
            # (n-group, k-group); the k-group scale is re-loaded inside the loop.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
            b_scale = tl.load(b_scale_ptr + off_experts)

    # Accumulate in fp32 regardless of input dtype; downcast on store.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if even_Ks:
            # K divides evenly into BLOCK_SIZE_K: no K-boundary mask needed.
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            # ".cv" / evict_first: stream B through cache — each expert tile is
            # presumably read once per program, so avoid polluting the cache.
            b = tl.load(b_ptrs, cache_modifier=".cv", eviction_policy="evict_first")
        else:
            # Ragged last K-block: mask loads past the end of K.
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # We accumulate along the K dimension.
        if use_int8_w8a16:
            # int8 weights are upcast to the compute type; scale applied after the loop.
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise dequant: apply the k-group's scales to this
                # partial product before adding it to the accumulator.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask,
                    mask=token_mask,
                    other=0.0,
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                # Per-tensor scales are constant over K; defer them to the epilogue.
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)

        # Advance both operand pointers to the next K-block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # --- Epilogue: routing weight, deferred dequant scales, downcast ---
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # Block-wise scales were already applied inside the K-loop.
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # C is indexed by the (unreduced) sorted token id, i.e. one row per
    # (token, topk) pair; padded rows are masked off via token_mask.
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

    tl.store(c_ptrs, accumulator, mask=c_mask)
 | 
