// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
#include "helper.h"

template <typename T, int TileRows, int TileColumns, int NumThreads>
__global__ void Wint25UnzipKernel(const uint16_t *zipped_weight_ptr,
                                  const T *super_scale_ptr,
                                  T *weight_ptr,
                                  const int64_t batch,
                                  const int64_t num_rows,
                                  const int64_t num_columns) {
  using UnzipFunctor = cutlass::gemm::threadblock::
      UnzipAndDequantFunctor<T, TileRows, TileColumns, NumThreads>;

  __shared__ T smem[TileRows * TileColumns];

  int64_t block_start_column = blockIdx.x * TileColumns;
  int64_t block_start_row = blockIdx.z * num_rows + blockIdx.y * TileRows;

  // 2.5-bit packing: 64 weight rows occupy 10 uint16 rows (64 * 2.5 = 160 bits).
  int64_t block_start_zipped_row = block_start_row * 10 / 64;
  int64_t block_zipped_offset =
      block_start_zipped_row * num_columns + block_start_column;
  const uint16_t *block_zipped_weight_ptr =
      zipped_weight_ptr + block_zipped_offset;

  const T *block_super_scale_ptr =
      super_scale_ptr + blockIdx.z * num_columns + block_start_column;

  // unzip to shared memory
  UnzipFunctor unzip_functor;
  unzip_functor(block_zipped_weight_ptr, block_super_scale_ptr, smem, num_columns);
  // make the unzipped tile visible to all threads before the write-back
  __syncthreads();

  // write back to global memory
  for (int row = 0; row < TileRows; ++row) {
    for (int col = 0; col < TileColumns; ++col) {
      int64_t global_row = block_start_row + row;
      int64_t global_col = block_start_column + col;
      weight_ptr[global_row * num_columns + global_col] =
          smem[row * TileColumns + col];
    }
  }
}

template <typename T, int TileRows, int TileColumns, int NumThreads>
__global__ void Wint2UnzipKernel(const uint8_t *zipped_weight_ptr,
                                 const uint8_t *local_scale_ptr,
                                 const float *code_scale_ptr,
                                 const float *code_zp_ptr,
                                 const T *super_scale_ptr,
                                 T *weight_ptr,
                                 const int64_t batch,
                                 const int64_t num_rows,
                                 const int64_t num_columns) {
  using UnzipFunctor = cutlass::gemm::threadblock::
      UnzipAndDequantFunctor<T, TileRows, TileColumns, NumThreads>;

  constexpr bool kUseAsyncLoad = true;

  __shared__ uint8_t zipped_smem[UnzipFunctor::kZippedSmemBytes +
                                 UnzipFunctor::kColumnWiseSmemBytes];
  __shared__ T smem[TileRows * TileColumns];

  int64_t block_start_column = blockIdx.x * TileColumns;
  int64_t block_start_row = blockIdx.z * num_rows + blockIdx.y * TileRows;

  // 2-bit packing: 4 weight rows share one uint8 row.
  int64_t block_start_zipped_row = block_start_row / 4;
  int64_t block_zipped_offset =
      block_start_zipped_row * num_columns + block_start_column;
  const uint8_t *block_zipped_weight_ptr =
      zipped_weight_ptr + block_zipped_offset;

  // local_scale is uint4: two 4-bit scales per byte, one per group of 64 rows.
  int64_t block_start_local_scale_row = block_start_row / (64 * 2);
  int64_t block_local_scale_offset =
      block_start_local_scale_row * num_columns + block_start_column;
  const uint8_t *block_local_scale_ptr =
      local_scale_ptr + block_local_scale_offset;

  const float *block_code_scale_ptr =
      code_scale_ptr + blockIdx.z * num_columns + block_start_column;
  const float *block_code_zp_ptr =
      code_zp_ptr + blockIdx.z * num_columns + block_start_column;
  const T *block_super_scale_ptr =
      super_scale_ptr
          ? super_scale_ptr + blockIdx.z * num_columns + block_start_column
          : nullptr;
  typename UnzipFunctor::Arguments args(
      zipped_smem, zipped_smem + UnzipFunctor::kZippedSmemBytes);

  // unzip to shared memory
  UnzipFunctor functor;
  if (kUseAsyncLoad) {
    functor.LoadAsync(block_zipped_weight_ptr, block_local_scale_ptr,
                      block_code_scale_ptr, block_code_zp_ptr,
                      block_super_scale_ptr, &args, num_columns, true);
    // commit the issued cp.async copies
    cutlass::arch::cp_async_fence();
    // wait for cp.async
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();
  } else {
    functor.Load(block_zipped_weight_ptr, block_local_scale_ptr,
                 block_code_scale_ptr, block_code_zp_ptr,
                 block_super_scale_ptr, &args, num_columns, true);
  }
  functor.Compute(args, smem, block_start_row);
  // make the unzipped tile visible to all threads before the write-back
  __syncthreads();

  // write back to global memory
  for (int row = 0; row < TileRows; ++row) {
    for (int col = 0; col < TileColumns; ++col) {
      int64_t global_row = block_start_row + row;
      int64_t global_col = block_start_column + col;
      weight_ptr[global_row * num_columns + global_col] =
          smem[row * TileColumns + col];
    }
  }
}

template <typename T>
void Wint25UnzipKernelLauncher(const uint16_t *zipped_weight,
                               const T *super_scale,
                               T *weight,
                               const int64_t batch,
                               const int64_t num_rows,
                               const int64_t num_columns) {
  constexpr int kTileRows = 64;
  constexpr int kTileColumns = 128;
  constexpr int kNumThreads = 128;

  const int block_dim_x = (num_columns + kTileColumns - 1) / kTileColumns;
  const int block_dim_y = (num_rows + kTileRows - 1) / kTileRows;

  dim3 block_dim(kNumThreads, 1, 1);
  dim3 grid_dim(block_dim_x, block_dim_y, batch);
  Wint25UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
      <<<grid_dim, block_dim>>>(zipped_weight, super_scale, weight,
                                batch, num_rows, num_columns);
}

template <typename T>
void Wint2UnzipKernelLauncher(const uint8_t *zipped_weight,
                              const uint8_t *local_scale,
                              const float *code_scale,
                              const float *code_zp,
                              const T *super_scale,
                              T *weight,
                              const int64_t batch,
                              const int64_t num_rows,
                              const int64_t num_columns) {
  constexpr int kTileRows = 64;
  constexpr int kTileColumns = 256;
  constexpr int kNumThreads = 256;

  const int block_dim_x = (num_columns + kTileColumns - 1) / kTileColumns;
  const int block_dim_y = (num_rows + kTileRows - 1) / kTileRows;

  dim3 block_dim(kNumThreads, 1, 1);
  dim3 grid_dim(block_dim_x, block_dim_y, batch);
  Wint2UnzipKernel<T, kTileRows, kTileColumns, kNumThreads>
      <<<grid_dim, block_dim>>>(zipped_weight, local_scale, code_scale, code_zp,
                                super_scale, weight, batch, num_rows, num_columns);
}

template <paddle::DataType D>
void WintxUnzipKernel(const paddle::Tensor &zipped_weight,
                      const paddle::optional<paddle::Tensor> &local_scale,
                      const paddle::optional<paddle::Tensor> &code_scale,
                      const paddle::optional<paddle::Tensor> &code_zp,
                      const paddle::optional<paddle::Tensor> &super_scale,
                      paddle::Tensor &weight,
                      const std::string &quant_method) {
  using data_t = typename PDTraits<D>::data_t;
  using NvType = typename PDTraits<D>::DataType;

  paddle::Tensor *super_scale_tensor =
      const_cast<paddle::Tensor *>(super_scale.get_ptr());
  const auto *super_scale_ptr =
      super_scale_tensor ? super_scale_tensor->data<data_t>() : nullptr;
  auto *weight_ptr = weight.data<data_t>();

  const int64_t batch = weight.shape()[0];
  const int64_t num_rows = weight.shape()[1];
  const int64_t num_columns = weight.shape()[2];

  if (quant_method == "weight_only_int2.5") {
    const auto *zipped_weight_ptr = zipped_weight.data<int16_t>();
    Wint25UnzipKernelLauncher(
        reinterpret_cast<const uint16_t *>(zipped_weight_ptr),
        reinterpret_cast<const NvType *>(super_scale_ptr),
        reinterpret_cast<NvType *>(weight_ptr),
        batch, num_rows, num_columns);
  } else if (quant_method == "weight_only_int2") {
    paddle::Tensor *local_scale_tensor =
        const_cast<paddle::Tensor *>(local_scale.get_ptr());
    paddle::Tensor *code_scale_tensor =
        const_cast<paddle::Tensor *>(code_scale.get_ptr());
    paddle::Tensor *code_zp_tensor =
        const_cast<paddle::Tensor *>(code_zp.get_ptr());
    Wint2UnzipKernelLauncher(
        zipped_weight.data<uint8_t>(),
        local_scale_tensor->data<uint8_t>(),
        code_scale_tensor->data<float>(),
        code_zp_tensor->data<float>(),
        reinterpret_cast<const NvType *>(super_scale_ptr),
        reinterpret_cast<NvType *>(weight_ptr),
        batch, num_rows, num_columns);
  } else {
    PD_THROW("Unsupported quant_method for WintxUnzip.");
  }
}

std::vector<paddle::Tensor> WintXUnzip(
    const paddle::Tensor &zipped_weight,
    const paddle::optional<paddle::Tensor> &local_scale,
    const paddle::optional<paddle::Tensor> &code_scale,
    const paddle::optional<paddle::Tensor> &code_zp,
    const paddle::optional<paddle::Tensor> &super_scale,
    const std::string &quant_method) {
  paddle::Tensor *local_scale_tensor =
      const_cast<paddle::Tensor *>(local_scale.get_ptr());
  paddle::Tensor *super_scale_tensor =
      const_cast<paddle::Tensor *>(super_scale.get_ptr());

  if (quant_method == "weight_only_int2.5") {
    PD_CHECK(super_scale_tensor, "super_scale must be set in wint2.5!");
  } else if (quant_method == "weight_only_int2") {
    PD_CHECK(local_scale_tensor, "local_scale must be set in wint2.0!");
  }

  auto place = zipped_weight.place();
  auto dtype = super_scale_tensor ? super_scale_tensor->dtype()
                                  : local_scale_tensor->dtype();

  auto output_dims = zipped_weight.dims();
  const int unzip_axis = 1;
  if (quant_method == "weight_only_int2.5") {
    // 10 packed uint16 rows expand to 64 weight rows.
    output_dims[unzip_axis] = output_dims[unzip_axis] / 10 * 64;
  } else if (quant_method == "weight_only_int2") {
    // 1 packed uint8 row expands to 4 weight rows.
    output_dims[unzip_axis] = output_dims[unzip_axis] * 4;
  } else {
    PD_THROW("Unsupported quant_method for WintxUnzip");
  }
  auto output_tensor = GetEmptyTensor(output_dims, dtype, place);

  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      WintxUnzipKernel<paddle::DataType::BFLOAT16>(
          zipped_weight, local_scale, code_scale, code_zp, super_scale,
          output_tensor, quant_method);
      break;
    case paddle::DataType::FLOAT16:
      WintxUnzipKernel<paddle::DataType::FLOAT16>(
          zipped_weight, local_scale, code_scale, code_zp, super_scale,
          output_tensor, quant_method);
      break;
    default:
      PD_THROW("Unsupported data type for WintxUnzip");
  }
  return {output_tensor};
}

std::vector<std::vector<int64_t>> WintXUnzipInferShape(
    const std::vector<int64_t> &zipped_weight_shape,
    const paddle::optional<std::vector<int64_t>> &local_scale_shape,
    const paddle::optional<std::vector<int64_t>> &code_scale_shape,
    const paddle::optional<std::vector<int64_t>> &code_zp_shape,
    const paddle::optional<std::vector<int64_t>> &super_scale_shape,
    const std::string &quant_method) {
  std::vector<int64_t> output_shape(zipped_weight_shape);
  const int unzip_axis = 1;
  if (quant_method == "weight_only_int2.5") {
    output_shape[unzip_axis] = zipped_weight_shape[unzip_axis] / 10 * 64;
    PD_CHECK(output_shape[unzip_axis] % 64 == 0,
             "unzip_size must be divisible by 64 in wint2.5!");
  } else if (quant_method == "weight_only_int2") {
    output_shape[unzip_axis] = zipped_weight_shape[unzip_axis] * 4;
    PD_CHECK(output_shape[unzip_axis] % 64 == 0,
             "unzip_size must be divisible by 64 in wint2!");
  } else {
    PD_THROW("Unsupported quant_method for WintxUnzip");
  }
  return {output_shape};
}

std::vector<paddle::DataType> WintXUnzipInferDtype(
    const paddle::DataType &zipped_weight_dtype,
    const paddle::optional<paddle::DataType> &local_scale_dtype,
    const paddle::optional<paddle::DataType> &code_scale_dtype,
    const paddle::optional<paddle::DataType> &code_zp_dtype,
    const paddle::optional<paddle::DataType> &super_scale_dtype) {
  if (super_scale_dtype.is_initialized()) {
    return {super_scale_dtype.get()};
  } else if (local_scale_dtype.is_initialized()) {
    return {local_scale_dtype.get()};
  } else {
    PD_THROW("Both super_scale and local_scale are not set for WintxUnzip.");
  }
}

PD_BUILD_STATIC_OP(winx_unzip)
    .Inputs({"zipped_weight",
             paddle::Optional("local_scale"),
             paddle::Optional("code_scale"),
             paddle::Optional("code_zp"),
             paddle::Optional("super_scale")})
    .Outputs({"weight"})
    .Attrs({"quant_method:std::string"})
    .SetKernelFn(PD_KERNEL(WintXUnzip))
    .SetInferShapeFn(PD_INFER_SHAPE(WintXUnzipInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(WintXUnzipInferDtype));
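
// ---------------------------------------------------------------------------
// Illustrative sketch (not used by the op above). The index math and shape
// inference rely on two packing ratios: for "weight_only_int2", four 2-bit
// values share one uint8, so a zipped dimension of R rows expands to 4 * R
// rows; for "weight_only_int2.5", 64 values occupy 64 * 2.5 = 160 bits =
// 10 uint16 words, so 10 zipped rows expand to 64 rows. The helper below only
// sketches plain 2-bit field extraction and assumes a low-to-high bit order;
// the actual layout and dequantization are defined by UnzipAndDequantFunctor
// in cutlass_extensions/gemm/threadblock/wint2x_unzip.h.
__host__ __device__ inline uint8_t ExtractInt2Field(uint8_t packed, int idx) {
  // idx in [0, 4): select the idx-th 2-bit field of the packed byte.
  return static_cast<uint8_t>((packed >> (idx * 2)) & 0x3);
}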