// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

// Round-up integer division; operands are parenthesized so the macro is safe
// for expression arguments.
#define CEILDIV(a, b) (((a) + (b) - 1) / (b))
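// Example: with GEMM_BLOCK_SIZE_M = 4, an expert holding 5 tokens occupies
// CEILDIV(5, 4) = 2 GEMM blocks, i.e. 8 padded slots.
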
// Scatter pass (launched second): write each flattened token index into its
// expert's padded segment of sorted_token_ids. cumsum_buffer holds the running
// write offset per expert, starting from the exclusive prefix sum computed by
// moe_align_block_size_kernel.
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
                                                    int32_t* __restrict__ sorted_token_ids,
                                                    int32_t* __restrict__ cumsum_buffer,
                                                    size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;

  // Grid-stride loop: each thread atomically claims the next free slot in the
  // segment belonging to its token's expert.
  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

// Alignment pass (launched first, as a single block): count tokens per expert,
// pad each expert's count up to a multiple of GEMM_BLOCK_SIZE_M, build the
// exclusive prefix sum, and record which expert owns each GEMM block.
template <typename scalar_t, int num_experts>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
                                            int32_t* __restrict__ expert_ids,
                                            int32_t* __restrict__ total_tokens_post_pad,
                                            int32_t GEMM_BLOCK_SIZE_M,
                                            size_t numel,
                                            int32_t* __restrict__ cumsum_buffer) {
  __shared__ int32_t tokens_per_ep[num_experts];

  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
    tokens_per_ep[i] = 0;
  }

  __syncthreads();

  // Histogram of tokens per expert, accumulated in shared memory.
  for (size_t i = threadIdx.x; i < numel; i += blockDim.x) {
    int expert_id = topk_ids[i];
    atomicAdd(&tokens_per_ep[expert_id], 1);
  }

  __syncthreads();

  // Serial prefix sum over the (small) expert count; each expert's segment is
  // padded up to a multiple of GEMM_BLOCK_SIZE_M.
  if (threadIdx.x == 0) {
    cumsum_buffer[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = tokens_per_ep[i - 1];
      cumsum_buffer[i] = cumsum_buffer[i - 1] + CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M;
    }
    *total_tokens_post_pad = cumsum_buffer[num_experts];
  }

  __syncthreads();

  // One thread per expert stamps its expert id onto every GEMM block it owns.
  // This requires num_experts <= blockDim.x, which holds for all supported
  // expert counts (the launcher uses 1024 threads).
  if (threadIdx.x < num_experts) {
    for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1]; i += GEMM_BLOCK_SIZE_M) {
      expert_ids[i / GEMM_BLOCK_SIZE_M] = threadIdx.x;
    }
  }
}

std::vector<std::vector<int64_t>> tritonmoe_preprocessInferShape(const std::vector<int64_t>& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) {
    int topk_ids_numel = topk_ids[0] * topk_ids[1];
    // Worst case: every expert needs (GEMM_BLOCK_SIZE_M - 1) padding slots.
    int max_num_tokens_padded = topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);

    std::vector<int64_t> sorted_ids = {max_num_tokens_padded};

    int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;
    std::vector<int64_t> expert_ids = {max_num_m_blocks};
    std::vector<int64_t> num_tokens_post_pad = {1};

    return {sorted_ids, expert_ids, num_tokens_post_pad};
}

std::vector<paddle::DataType> tritonmoe_preprocessInferDtype(const paddle::DataType& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) {
    return {paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32};
}

/*
Suppose num_experts = 8, GEMM_BLOCK_SIZE_M = 4, and
topk_ids.shape = [4, 4] (i.e. topk = 4):

topk_ids =
[7 6 5 4
 1 2 3 4
 0 1 2 3
 0 3 2 1]

Then the returned `sorted_ids`, shown one expert per row and padded with the
sentinel topk_ids_numel = 16, is:

8,12,16,16
4,9,15,16
5,10,14,16
6,11,13,16
3,7,16,16
2,16,16,16
1,16,16,16
0,16,16,16
*/

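// Reference-only host sketch (not used by the op; the helper name
// `reference_moe_align` is illustrative). It reproduces the two GPU passes
// above sequentially, so it yields the example layout deterministically,
// whereas the atomic scatter on the GPU may order tokens within an expert
// arbitrarily.
static inline std::vector<int32_t> reference_moe_align(
    const std::vector<int64_t>& topk_ids_flat, int num_experts, int block_m) {
  const int numel = static_cast<int>(topk_ids_flat.size());
  const int max_padded = numel + num_experts * (block_m - 1);

  // Alignment pass: per-expert histogram, then padded exclusive prefix sum.
  std::vector<int32_t> counts(num_experts, 0);
  for (int i = 0; i < numel; ++i) counts[topk_ids_flat[i]]++;
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int e = 1; e <= num_experts; ++e) {
    cumsum[e] = cumsum[e - 1] + CEILDIV(counts[e - 1], block_m) * block_m;
  }

  // Scatter pass: numel doubles as the padding sentinel, matching the
  // paddle::full fill value in the launcher below.
  std::vector<int32_t> sorted_ids(max_padded, numel);
  std::vector<int32_t> write_pos(cumsum.begin(), cumsum.end() - 1);
  for (int i = 0; i < numel; ++i) {
    sorted_ids[write_pos[topk_ids_flat[i]]++] = i;
  }
  return sorted_ids;
}
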
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) {
    int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];
    int max_num_tokens_padded = topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1);

    // sorted_ids is pre-filled with topk_ids_numel, which acts as the padding
    // sentinel for slots no token is scattered into.
    auto sorted_ids = paddle::full(
        {max_num_tokens_padded},
        topk_ids_numel,
        paddle::DataType::INT32,
        topk_ids.place()
    );

    int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M;

    auto expert_ids = paddle::empty(
        {max_num_m_blocks},
        paddle::DataType::INT32,
        topk_ids.place()
    );

    auto num_tokens_post_pad = paddle::empty(
        {1},
        paddle::DataType::INT32,
        topk_ids.place()
    );

    auto cumsum_buffer = paddle::empty(
        {num_experts + 1},
        paddle::DataType::INT32,
        topk_ids.place()
    );

    auto stream = topk_ids.stream();
    using scalar_t = int64_t;

    // num_experts must be a compile-time constant because the kernel sizes its
    // shared-memory histogram with it, so dispatch over the supported values.
    // The macro parameter is named NUM_EXPERTS to avoid shadowing the int64_t
    // num_experts argument.
    #define run_align_kernel(NUM_EXPERTS) \
    do { \
      auto align_kernel = moe_align_block_size_kernel<scalar_t, NUM_EXPERTS>; \
      align_kernel<<<1, 1024, 0, stream>>>( \
          topk_ids.data<scalar_t>(), \
          expert_ids.data<int32_t>(), \
          num_tokens_post_pad.data<int32_t>(), \
          GEMM_BLOCK_SIZE_M, \
          topk_ids_numel, \
          cumsum_buffer.data<int32_t>()); \
    } while (0)

    if (num_experts == 2) {
      run_align_kernel(2);
    } else if (num_experts == 8) {
      run_align_kernel(8);
    } else if (num_experts == 64) {
      run_align_kernel(64);
    } else if (num_experts == 128) {
      run_align_kernel(128);
    } else if (num_experts == 160) {
      run_align_kernel(160);
    } else if (num_experts == 256) {
      run_align_kernel(256);
    } else {
      PD_THROW("Unsupported num_experts: %d", num_experts);
    }

    // Scatter pass launch: grid-stride kernel with a capped grid size.
    const int block_threads = 256;
    const int num_blocks = CEILDIV(topk_ids_numel, block_threads);
    const int max_blocks = 65535;
    const int actual_blocks = std::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data<scalar_t>(),
                                                             sorted_ids.data<int32_t>(),
                                                             cumsum_buffer.data<int32_t>(),
                                                             topk_ids_numel);

    return {sorted_ids, expert_ids, num_tokens_post_pad};
}

PD_BUILD_STATIC_OP(tritonmoe_preprocess)
    .Inputs({"topk_ids"})
    .Attrs({"num_experts: int64_t", "GEMM_BLOCK_SIZE_M: int64_t"})
    .Outputs({"sorted_ids", "expert_ids", "num_tokens_post_pad"})
    .SetKernelFn(PD_KERNEL(tritonmoe_preprocess_kernel))
    .SetInferShapeFn(PD_INFER_SHAPE(tritonmoe_preprocessInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(tritonmoe_preprocessInferDtype));
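
// After the op library is built and loaded, PD_BUILD_STATIC_OP exposes this
// under the Python name `tritonmoe_preprocess`. A usage sketch (the module
// path depends on how the package is built and is omitted here):
//
//   sorted_ids, expert_ids, num_tokens_post_pad = tritonmoe_preprocess(
//       topk_ids,           # int64 tensor of shape [num_tokens, topk]
//       num_experts,        # one of 2, 8, 64, 128, 160, 256 (see dispatch)
//       GEMM_BLOCK_SIZE_M)  # GEMM tile height used for padding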