mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 20:02:53 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			317 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			317 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
#include "helper.h"

 | |
// Unzips wint2.5 packed weights into dequantized values of type T.
// Grid layout: x -> column tiles, y -> row tiles, z -> batch index.
// Each block cooperatively unzips one TileRows x TileColumns tile into
// shared memory via UnzipAndDequantFunctor, then writes it back to global
// memory. Shared memory requirement: TileRows * TileColumns * sizeof(T).
template <typename T, int TileRows, int TileColumns, int NumThreads>
__global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
                                  const T *super_scale_ptr, T *weight_ptr,
                                  const int64_t batch, const int64_t num_rows,
                                  const int64_t num_columns) {
  using UnzipFunctor = cutlass::gemm::threadblock::UnzipAndDequantFunctor<
      T, cutlass::WintQuantMethod::kWeightOnlyInt25, TileRows, TileColumns,
      NumThreads>;

  __shared__ T smem[TileRows * TileColumns];

  int64_t block_start_column = blockIdx.x * TileColumns;
  int64_t block_start_row = blockIdx.z * num_rows + blockIdx.y * TileRows;
  // wint2.5 packs 64 unzipped rows into 10 zipped rows along the row axis
  // (matches the /10*64 expansion done on the host side).
  int64_t block_start_zipped_row = block_start_row * 10 / 64;

  int64_t block_zipped_offset =
      block_start_zipped_row * num_columns + block_start_column;
  const uint16_t *block_zipped_weight_ptr =
      zipped_weight_ptr + block_zipped_offset;

  // super_scale has one value per (batch, column).
  const T *block_super_scale_ptr =
      super_scale_ptr + blockIdx.z * num_columns + block_start_column;

  // unzip to shared memory
  UnzipFunctor unzip_functor;
  unzip_functor(block_zipped_weight_ptr, block_super_scale_ptr, smem,
                num_columns);

  // Make the whole unzipped tile visible to every thread before the
  // cross-thread reads below (the functor's smem writes are per-thread).
  __syncthreads();

  // Write back to global memory. The tile is distributed across the block's
  // threads (the original code had every thread redundantly store the entire
  // tile), and stores are guarded so the tail tiles created by the host-side
  // ceil-div grid sizing never write out of bounds.
  for (int idx = threadIdx.x; idx < TileRows * TileColumns; idx += NumThreads) {
    int row = idx / TileColumns;
    int col = idx % TileColumns;
    int64_t row_in_batch = blockIdx.y * static_cast<int64_t>(TileRows) + row;
    int64_t global_col = block_start_column + col;
    if (row_in_batch < num_rows && global_col < num_columns) {
      weight_ptr[(blockIdx.z * num_rows + row_in_batch) * num_columns +
                 global_col] = smem[row * TileColumns + col];
    }
  }
}
 | |
| 
 | |
// Unzips wint2 packed weights into dequantized values of type T.
// Grid layout: x -> column tiles, y -> row tiles, z -> batch index.
// zipped_weight packs 4 int2 row-values per byte; local_scale packs two
// uint4 group scales per byte (one per 64-row group). super_scale_ptr may
// be nullptr (the functor is expected to handle the missing super scale).
template <typename T, int64_t TileRows, int64_t TileColumns, int NumThreads>
__global__ void
Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
                 const uint8_t *local_scale_ptr, const float *code_scale_ptr,
                 const float *code_zp_ptr, const T *super_scale_ptr,
                 T *weight_ptr, const int64_t batch, const int64_t num_rows,
                 const int64_t num_columns) {
  using UnzipFunctor = cutlass::gemm::threadblock::UnzipAndDequantFunctor<
      T, cutlass::WintQuantMethod::kWeightOnlyInt2, TileRows, TileColumns,
      NumThreads>;

  // Compile-time switch between cp.async staging and a plain blocking load.
  constexpr bool kUseAsyncLoad = true;

  __shared__ uint8_t zipped_smem[UnzipFunctor::kZippedSmemBytes +
                                 UnzipFunctor::kColumnWiseSmemBytes];
  __shared__ T smem[TileRows * TileColumns];

  int64_t block_start_column = blockIdx.x * TileColumns;
  int64_t block_start_row = blockIdx.z * num_rows + blockIdx.y * TileRows;

  // 4 int2 values per zipped byte along the row dimension.
  int64_t block_start_zipped_row = block_start_row / 4;
  int64_t block_zipped_offset =
      block_start_zipped_row * num_columns + block_start_column;
  const uint8_t *block_zipped_weight_ptr =
      zipped_weight_ptr + block_zipped_offset;

  // local_scale is uint4: one scale per 64-row group, two groups per byte.
  int64_t block_start_local_scale_row = block_start_row / (64 * 2);
  int64_t block_local_scale_offset =
      block_start_local_scale_row * num_columns + block_start_column;
  const uint8_t *block_local_scale_ptr =
      local_scale_ptr + block_local_scale_offset;

  // code_scale / code_zp / super_scale hold one value per (batch, column).
  const float *block_code_scale_ptr =
      code_scale_ptr + blockIdx.z * num_columns + block_start_column;
  const float *block_code_zp_ptr =
      code_zp_ptr + blockIdx.z * num_columns + block_start_column;
  const T *block_super_scale_ptr =
      super_scale_ptr
          ? super_scale_ptr + blockIdx.z * num_columns + block_start_column
          : nullptr;

  // Staging buffers inside shared memory: zipped weights first, then the
  // column-wise metadata region.
  typename UnzipFunctor::Arguments args(
      zipped_smem, zipped_smem + UnzipFunctor::kZippedSmemBytes);

  UnzipFunctor functor;

  if (kUseAsyncLoad) {
    functor.LoadAsync(block_zipped_weight_ptr, block_local_scale_ptr,
                      block_code_scale_ptr, block_code_zp_ptr,
                      block_super_scale_ptr, &args, num_columns, true);

    // Commit the issued cp.async operations.
    cutlass::arch::cp_async_fence();

    // Wait for all committed cp.async operations to complete.
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();
  } else {
    functor.Load(block_zipped_weight_ptr, block_local_scale_ptr,
                 block_code_scale_ptr, block_code_zp_ptr,
                 block_super_scale_ptr, &args, num_columns, true);
  }

  // Dequantize the staged tile into smem.
  functor.Compute(args, smem, block_start_row);

  // Make the whole dequantized tile visible to every thread before the
  // cross-thread reads below (Compute's smem writes are per-thread).
  __syncthreads();

  // Write back to global memory. The tile is split across the block's
  // threads (originally every thread redundantly stored the full tile), and
  // stores are guarded so the tail tiles created by the host-side ceil-div
  // grid sizing never write out of bounds.
  for (int64_t idx = threadIdx.x; idx < TileRows * TileColumns;
       idx += NumThreads) {
    int64_t row = idx / TileColumns;
    int64_t col = idx % TileColumns;
    int64_t row_in_batch = blockIdx.y * TileRows + row;
    int64_t global_col = block_start_column + col;
    if (row_in_batch < num_rows && global_col < num_columns) {
      weight_ptr[(blockIdx.z * num_rows + row_in_batch) * num_columns +
                 global_col] = smem[row * TileColumns + col];
    }
  }
}
 | |
| 
 | |
// Host-side launcher for Wint25UnzipKernel.
// Tiles the (num_rows, num_columns) weight plane into 64x128 blocks of 128
// threads each; grid.z iterates over the batch dimension. Launches on the
// default stream.
template <typename T>
void Wint25UnzipKernelLauncher(const uint16_t *zipped_weight,
                               const T *supper_scale, T *weight,
                               const int64_t batch, const int64_t num_rows,
                               const int64_t num_columns) {
  constexpr int kTileRows = 64;
  constexpr int kTileColumns = 128;
  constexpr int kNumThreads = 128;

  // Ceil-div so partial tiles at the right/bottom edges are still covered.
  const int num_column_tiles =
      static_cast<int>((num_columns + kTileColumns - 1) / kTileColumns);
  const int num_row_tiles =
      static_cast<int>((num_rows + kTileRows - 1) / kTileRows);

  const dim3 threads_per_block(kNumThreads, 1, 1);
  const dim3 blocks(num_column_tiles, num_row_tiles, batch);

  Wint25UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
      <<<blocks, threads_per_block>>>(zipped_weight, supper_scale, weight,
                                      batch, num_rows, num_columns);
}
 | |
| 
 | |
// Host-side launcher for Wint2UnzipKernel.
// Tiles the (num_rows, num_columns) weight plane into 64x256 blocks of 256
// threads each; grid.z iterates over the batch dimension. Launches on the
// default stream.
template <typename T>
void Wint2UnzipKernelLauncher(const uint8_t *zipped_weight,
                              const uint8_t *local_scale,
                              const float *code_scale, const float *code_zp,
                              const T *supper_scale, T *weight,
                              const int64_t batch, const int64_t num_rows,
                              const int64_t num_columns) {
  constexpr int kTileRows = 64;
  constexpr int kTileColumns = 256;
  constexpr int kNumThreads = 256;

  // Ceil-div so partial tiles at the right/bottom edges are still covered.
  const int num_column_tiles =
      static_cast<int>((num_columns + kTileColumns - 1) / kTileColumns);
  const int num_row_tiles =
      static_cast<int>((num_rows + kTileRows - 1) / kTileRows);

  const dim3 threads_per_block(kNumThreads, 1, 1);
  const dim3 blocks(num_column_tiles, num_row_tiles, batch);

  Wint2UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
      <<<blocks, threads_per_block>>>(zipped_weight, local_scale, code_scale,
                                      code_zp, supper_scale, weight, batch,
                                      num_rows, num_columns);
}
 | |
| 
 | |
// Dispatches the device unzip kernel for one paddle dtype T.
// weight is a [batch, num_rows, num_columns] tensor that receives the
// dequantized data. For "weight_only_int2" the local_scale / code_scale /
// code_zp optionals are required; super_scale is optional in that path.
template <paddle::DataType T>
void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
                      const paddle::optional<paddle::Tensor> &local_scale,
                      const paddle::optional<paddle::Tensor> &code_scale,
                      const paddle::optional<paddle::Tensor> &code_zp,
                      const paddle::optional<paddle::Tensor> &super_scale,
                      paddle::Tensor &weight, const std::string &quant_method) {
  using data_t = typename PDTraits<T>::data_t;
  using NvType = typename PDTraits<T>::DataType;

  paddle::Tensor *super_scale_tensor =
      const_cast<paddle::Tensor *>(super_scale.get_ptr());
  // super_scale is optional; pass nullptr through when absent.
  const auto *super_scale_ptr =
      super_scale_tensor ? super_scale_tensor->data<data_t>() : nullptr;

  auto *weight_ptr = weight.data<data_t>();

  const int64_t batch = weight.shape()[0];
  const int64_t num_rows = weight.shape()[1];
  const int64_t num_columns = weight.shape()[2];

  if (quant_method == "weight_only_int2.5") {
    const auto *zipped_weight_ptr = zipped_weight.data<int16_t>();
    Wint25UnzipKernelLauncher<NvType>(
        reinterpret_cast<const uint16_t *>(zipped_weight_ptr),
        reinterpret_cast<const NvType *>(super_scale_ptr),
        reinterpret_cast<NvType *>(weight_ptr), batch, num_rows, num_columns);
  } else if (quant_method == "weight_only_int2") {
    paddle::Tensor *local_scale_tensor =
        const_cast<paddle::Tensor *>(local_scale.get_ptr());
    paddle::Tensor *code_scale_tensor =
        const_cast<paddle::Tensor *>(code_scale.get_ptr());
    paddle::Tensor *code_zp_tensor =
        const_cast<paddle::Tensor *>(code_zp.get_ptr());
    // Guard against missing optional inputs before dereferencing them; the
    // original code would null-deref if any of these were absent.
    PD_CHECK(local_scale_tensor,
             "local_scale must be set for weight_only_int2!");
    PD_CHECK(code_scale_tensor,
             "code_scale must be set for weight_only_int2!");
    PD_CHECK(code_zp_tensor, "code_zp must be set for weight_only_int2!");

    Wint2UnzipKernelLauncher<NvType>(
        zipped_weight.data<uint8_t>(), local_scale_tensor->data<uint8_t>(),
        code_scale_tensor->data<float>(), code_zp_tensor->data<float>(),
        reinterpret_cast<const NvType *>(super_scale_ptr),
        reinterpret_cast<NvType *>(weight_ptr), batch, num_rows, num_columns);
  } else {
    PD_THROW("Unsupported quant_method for WintxUnzip.");
  }
}
 | |
| 
 | |
// Custom-op entry point: validates inputs, allocates the output weight
// tensor, and dispatches WintxUnzipKernel by output dtype.
// Returns a single tensor whose axis-1 extent is the unzipped row count.
std::vector<paddle::Tensor>
WintXUnzip(const paddle::Tensor &zipped_weight,
           const paddle::optional<paddle::Tensor> &local_scale,
           const paddle::optional<paddle::Tensor> &code_scale,
           const paddle::optional<paddle::Tensor> &code_zp,
           const paddle::optional<paddle::Tensor> &super_scale,
           const std::string &quant_method) {
  paddle::Tensor *local_scale_tensor =
      const_cast<paddle::Tensor *>(local_scale.get_ptr());
  paddle::Tensor *super_scale_tensor =
      const_cast<paddle::Tensor *>(super_scale.get_ptr());
  if (quant_method == "weight_only_int2.5") {
    PD_CHECK(super_scale_tensor, "super_scale must be set in wint2.5!");
  } else if (quant_method == "weight_only_int2") {
    PD_CHECK(local_scale_tensor, "local_scale must be set in wint2.0!");
  }

  auto place = zipped_weight.place();
  // Output dtype follows super_scale when present, otherwise local_scale.
  // NOTE(review): in the wint2 path local_scale holds uint8 data, so hitting
  // the fallback would land in the unsupported-dtype branch below —
  // presumably super_scale is always provided in practice; confirm with
  // callers.
  auto dtype = super_scale_tensor ? super_scale_tensor->dtype()
                                  : local_scale_tensor->dtype();

  auto output_dims = zipped_weight.dims();
  const int unzip_axis = 1;
  if (quant_method == "weight_only_int2.5") {
    // wint2.5: 10 zipped rows expand to 64 unzipped rows.
    output_dims[unzip_axis] = output_dims[unzip_axis] / 10 * 64;
  } else if (quant_method == "weight_only_int2") {
    // wint2: each zipped byte holds 4 int2 values along the row axis.
    output_dims[unzip_axis] = output_dims[unzip_axis] * 4;
  } else {
    // Fixed message: this branch rejects the quant_method, not a data type
    // (the original text duplicated the dtype-switch message below).
    PD_THROW("Unsupported quant_method for WintxUnzip");
  }
  auto output_tensor = GetEmptyTensor(output_dims, dtype, place);

  switch (dtype) {
  case paddle::DataType::BFLOAT16:
    WintxUnzipKernel<paddle::DataType::BFLOAT16>(
        zipped_weight, local_scale, code_scale, code_zp, super_scale,
        output_tensor, quant_method);
    break;
  case paddle::DataType::FLOAT16:
    WintxUnzipKernel<paddle::DataType::FLOAT16>(
        zipped_weight, local_scale, code_scale, code_zp, super_scale,
        output_tensor, quant_method);
    break;
  default:
    PD_THROW("Unsupported data type for WintxUnzip");
  }
  return {output_tensor};
}
 | |
| 
 | |
// Shape-inference hook for the winx_unzip op: the output shape equals the
// zipped weight shape with axis 1 expanded by the quant method's packing
// factor (10 rows -> 64 for wint2.5, 1 byte -> 4 rows for wint2).
std::vector<std::vector<int64_t>> WintXUnzipInferShape(
    const std::vector<int64_t> &zipped_weight_shape,
    const paddle::optional<std::vector<int64_t>> &local_scale_shape,
    const paddle::optional<std::vector<int64_t>> &code_scale_shape,
    const paddle::optional<std::vector<int64_t>> &code_zp_shape,
    const paddle::optional<std::vector<int64_t>> &super_scale_shape,
    const std::string &quant_method) {
  const int unzip_axis = 1;
  std::vector<int64_t> output_shape = zipped_weight_shape;
  const int64_t zipped_extent = zipped_weight_shape[unzip_axis];

  if (quant_method == "weight_only_int2.5") {
    output_shape[unzip_axis] = zipped_extent / 10 * 64;
    PD_CHECK(output_shape[unzip_axis] % 64 == 0,
             "unzip_size must be divisible by 64 in wint2.5!");
    return {output_shape};
  }
  if (quant_method == "weight_only_int2") {
    output_shape[unzip_axis] = zipped_extent * 4;
    PD_CHECK(output_shape[unzip_axis] % 64 == 0,
             "unzip_size must be divisible by 64 in wint2!");
    return {output_shape};
  }
  PD_THROW("Unsupported quant_type for WintxUnzip");
}
 | |
| 
 | |
// Dtype-inference hook for the winx_unzip op: the output dtype mirrors
// super_scale when it is given, otherwise local_scale; at least one of the
// two must be present.
std::vector<paddle::DataType> WintXUnzipInferDtype(
    const paddle::DataType &zipped_weight_dtype,
    const paddle::optional<paddle::DataType> &local_scale_dtype,
    const paddle::optional<paddle::DataType> &code_scale_dtype,
    const paddle::optional<paddle::DataType> &code_zp_dtype,
    const paddle::optional<paddle::DataType> &super_scale_dtype) {
  if (super_scale_dtype.is_initialized()) {
    return {super_scale_dtype.get()};
  }
  if (local_scale_dtype.is_initialized()) {
    return {local_scale_dtype.get()};
  }
  PD_THROW("Both super_scale and local_scale are not set for WintxUnzip.");
}
 | |
| 
 | |
// Registers the static custom op "winx_unzip" with its kernel, shape, and
// dtype inference functions. NOTE(review): the op name looks like a typo for
// "wintx_unzip", but renaming it would break existing callers, so it is kept.
PD_BUILD_STATIC_OP(winx_unzip)
    .Inputs({"zipped_weight", paddle::Optional("local_scale"),
             paddle::Optional("code_scale"), paddle::Optional("code_zp"),
             paddle::Optional("super_scale")})
    .Outputs({"weight"})
    .Attrs({"quant_method:std::string"})
    .SetKernelFn(PD_KERNEL(WintXUnzip))
    .SetInferShapeFn(PD_INFER_SHAPE(WintXUnzipInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(WintXUnzipInferDtype));
