/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Converts a row-major [m, n] int8 tensor into a column-major copy.
void row_major_to_column_major(int8_t* col_major_tensor,
                               const int8_t* row_major_tensor,
                               const std::vector<size_t>& shape) {
  size_t m = shape[0];
  size_t n = shape[1];
  for (size_t i = 0; i < m * n; ++i) {
    size_t im = i / n;
    size_t in = i % n;
    col_major_tensor[in * m + im] = row_major_tensor[im * n + in];
  }
}
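
// Illustrative sketch (the helper name below is ours, not part of the original
// file): a tiny self-check of the index mapping above on a 2x3 tensor, where
// row-major element (r, c) at offset r * n + c must land at column-major
// offset c * m + r.
inline bool example_row_to_column_major_2x3() {
  const std::vector<size_t> shape{2, 3};
  const int8_t row_major[6] = {0, 1, 2, 3, 4, 5};
  int8_t col_major[6] = {0};
  row_major_to_column_major(col_major, row_major, shape);
  // Columns of [[0, 1, 2], [3, 4, 5]] read top-to-bottom, left-to-right.
  const int8_t expected[6] = {0, 3, 1, 4, 2, 5};
  return std::equal(col_major, col_major + 6, expected);
}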

void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr,
                                           int64_t num_elts) {
  int8_t* int8_tensor = int8_tensor_ptr;

  // Step 1: shift every int8 value up by 128 so the stored values become unsigned,
  // which keeps the CUDA dequantize path as cheap as possible.
  for (int64_t ii = 0; ii < num_elts; ++ii) {
    int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
    // int8_tensor[ii] = int8_t(int(int8_tensor[ii]));  // (disabled variant without the +128 bias)
  }

  // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout.
  // This has no performance benefit and is purely so that int4 and int8 have the same layout.
  // Pictorially, this does the following:
  // bit 32                                             0
  //   [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
  //
  // And it will rearrange the output 32 bit register to be the following:
  // bit 32                                             0
  //   [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
  for (int64_t base = 0; base < num_elts; base += 4) {
    std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
  }
}
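
// Illustrative sketch (helper name is ours): the transform above applied to one
// 32-bit group. Each byte is shifted by +128 and bytes 1 and 2 swap, so
// {e0, e1, e2, e3} becomes {e0 + 128, e2 + 128, e1 + 128, e3 + 128}.
inline bool example_add_bias_and_interleave_int8s() {
  int8_t elts[4] = {-128, -127, -126, -125};
  add_bias_and_interleave_int8s_inplace(elts, 4);
  const int8_t expected[4] = {0, 2, 1, 3};
  return std::equal(elts, elts + 4, expected);
}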

void subbyte_transpose_impl_int4(int8_t* transposed_quantized_tensor,
                                 const int8_t* quantized_tensor,
                                 const std::vector<size_t>& shape) {
  const int bits_per_elt = 4;

  const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

  const size_t col_bytes = num_cols * bits_per_elt / 8;
  const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
  const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;

  const uint8_t* input_byte_ptr = reinterpret_cast<const uint8_t*>(quantized_tensor);
  uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);

  // static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
  static constexpr int ELTS_PER_BYTE = 2;

  static constexpr int M_TILE_L1 = 64;
  static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
  uint8_t cache_buf[M_TILE_L1][N_TILE_L1];

  static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);

  // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples
  // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it
  // allows GCC to emit vector instructions.

  const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
  const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;

  for (size_t expert = 0; expert < num_experts; ++expert) {
    const size_t matrix_offset = expert * num_rows * col_bytes;
    for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) {
      for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) {
        const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
        const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);

        // Load one tile of packed bytes into the stack tile buffer.
        for (int ii = 0; ii < M_TILE_L1; ++ii) {
          const int row = row_tile_start + ii;

          for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
            const int col = col_tile_start_byte + jj;

            const size_t logical_src_offset = matrix_offset + row * col_bytes + col;

            if (row < row_limit && col < col_limit) {
              for (int v = 0; v < VECTOR_WIDTH; ++v) {
                cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
              }
            }
          }
        }

        // Transpose the tile in place, swapping one 4-bit element pair at a time.
        for (int ii = 0; ii < M_TILE_L1; ++ii) {
          // Using M_TILE_L1 here is deliberate since we assume that the cache tile
          // is square in the number of elements (not necessarily the number of bytes).
          for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
            const int ii_byte = ii / ELTS_PER_BYTE;
            const int ii_bit_offset = ii % ELTS_PER_BYTE;

            const int jj_byte = jj / ELTS_PER_BYTE;
            const int jj_bit_offset = jj % ELTS_PER_BYTE;

            uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
            uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));

            cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
            cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));

            cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
            cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
          }
        }

        // Write the transposed tile back to its mirrored position in the output.
        const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
        const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;

        const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
        const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);

        for (int ii = 0; ii < M_TILE_L1; ++ii) {
          const int row = row_tile_start_trans + ii;
          for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
            const int col = col_tile_start_byte_trans + jj;

            const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col;

            if (row < row_limit_trans && col < col_limit_trans) {
              for (int v = 0; v < VECTOR_WIDTH; ++v) {
                output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
              }
            }
          }
        }
      }
    }
  }
}
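
// Illustrative sketch (the helper below and its buffer names are ours): a
// self-check of the packed-int4 transpose on a single 64x64 tile. Element
// (r, c) of the input lives in the low or high nibble of byte r * 32 + c / 2
// (even element index = low nibble) and must end up as element (c, r) of the
// output.
inline bool example_subbyte_transpose_64x64() {
  const std::vector<size_t> shape{64, 64};
  std::vector<int8_t> src(64 * 64 / 2), dst(64 * 64 / 2);
  for (size_t ii = 0; ii < src.size(); ++ii) {
    // Arbitrary nibble pattern packed two-per-byte.
    src[ii] = int8_t(uint8_t((ii % 16) | (((ii + 7) % 16) << 4)));
  }
  subbyte_transpose_impl_int4(dst.data(), src.data(), shape);

  auto nibble = [](const std::vector<int8_t>& buf, size_t r, size_t c) {
    const uint8_t byte = uint8_t(buf[r * 32 + c / 2]);
    return (byte >> (4 * (c % 2))) & 0xF;
  };
  for (size_t r = 0; r < 64; ++r) {
    for (size_t c = 0; c < 64; ++c) {
      if (nibble(dst, c, r) != nibble(src, r, c)) return false;
    }
  }
  return true;
}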

void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) {
  const size_t num_bytes = num_elts / 2;

  // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little
  // instructions as possible in the CUDA code. (Note: as written, the loop below unpacks and repacks the
  // two nibbles of each byte without adding a bias, so the stored values are left unchanged.)
  for (size_t ii = 0; ii < num_bytes; ++ii) {
    int8_t transformed_packed_int4s = 0;
    // We don't need to mask in these ops since everything should be in the range 0-15
    int8_t transformed_first_elt = (packed_int4_tensor[ii] & 0x0F);
    int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4);

    transformed_packed_int4s |= transformed_first_elt;
    transformed_packed_int4s |= (transformed_second_elt << 4);
    packed_int4_tensor[ii] = transformed_packed_int4s;
  }

  // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift
  // & logical instructions that are needed to extract the int4s in the GEMM main loop. Pictorially, the
  // loop below will do the following: take as input a 32 bit register with layout:
  // bit 32                                             0
  //   [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits)
  //
  // And it will rearrange the output 32 bit register to be the following:
  // bit 32                                             0
  //   [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)

  // FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout");
  const size_t num_registers = num_bytes / 4;

  uint32_t* register_ptr = reinterpret_cast<uint32_t*>(packed_int4_tensor);
  for (size_t ii = 0; ii < num_registers; ++ii) {
    const uint32_t current_register = register_ptr[ii];
    uint32_t transformed_register = 0;

    for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
      const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
      const int src_shift = 4 * src_idx;
      const int dest_shift = 4 * dest_idx;

      const uint32_t src_bits = (current_register >> src_shift) & 0xF;
      transformed_register |= (src_bits << dest_shift);
    }
    register_ptr[ii] = transformed_register;
  }
}
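
// Illustrative sketch (helper name is ours; assumes a little-endian host, as the
// reinterpret_casts in this file already do): the register relayout above turns a
// 32-bit group whose nibbles are 0x76543210 (elt_i == i) into 0x75316420, i.e.
// odd elements in the high half and even elements in the low half.
inline bool example_interleave_int4_register() {
  // Little-endian bytes of 0x76543210.
  alignas(uint32_t) int8_t packed[4] = {0x10, 0x32, 0x54, 0x76};
  add_bias_and_interleave_int4s_inplace(packed, 8);
  // Little-endian bytes of 0x75316420.
  const int8_t expected[4] = {0x20, 0x64, 0x31, 0x75};
  return std::equal(packed, packed + 4, expected);
}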

void permute_B_rows_for_mixed_and_int8_gemm(int8_t* permuted_quantized_tensor,
                                            const int8_t* quantized_tensor,
                                            const std::vector<size_t>& shape,
                                            const int64_t arch_version) {
  // We only want to run this step for weight only quant.
  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

  const int BITS_PER_ELT = 8;
  const int K = 16 / BITS_PER_ELT;
  const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
  const int ELTS_PER_REG = 32 / BITS_PER_ELT;

  const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
  uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);

  int MMA_SHAPE_N = 8;
  int B_ROWS_PER_MMA = 8 * K;
  const int elts_in_int32 = 32 / BITS_PER_ELT;

  const int num_vec_cols = num_cols / elts_in_int32;

  // The code is written as below so it works for both int8 and packed int4.
  for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
    for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
      for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
        const int write_row = base_row + tile_row;
        const int tile_read_row =
            4 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);

        const int read_row = base_row + tile_read_row;
        const int read_col = write_col;

        const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
        const int64_t write_offset = int64_t(write_row) * num_vec_cols + write_col;

        output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
      }
    }
  }
}

// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures.
// The data is permuted such that:
// For int8, each group of 16 rows is permuted using the map below:
//  0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
//  0 1 2 3 4 5  6  7
template<int bits=8>
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
                                   const int8_t* quantized_tensor,
                                   const std::vector<size_t>& shape,
                                   const int64_t arch_version) {
  // We only want to run this step for weight only quant.
  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

  const int BITS_PER_ELT = bits;
  const int K = 16 / BITS_PER_ELT;
  const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
  const int ELTS_PER_REG = 32 / BITS_PER_ELT;

  const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
  uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);

  int MMA_SHAPE_N = 8;
  int B_ROWS_PER_MMA = 8 * K;
  const int elts_in_int32 = 32 / BITS_PER_ELT;

  const int num_vec_cols = num_cols / elts_in_int32;

  // The code is written as below so it works for both int8 and packed int4.
  for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
    for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
      for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
        const int write_row = base_row + tile_row;
        const int tile_read_row =
            8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
        if (base_row == 0 && write_col == 0) {
          std::cout << "tile_read_row:" << tile_read_row << std::endl;
        }
        const int read_row = base_row + tile_read_row;
        const int read_col = write_col;

        const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
        const int64_t write_offset = int64_t(write_row) * num_vec_cols + write_col;

        output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
      }
    }
  }
}
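
// Illustrative sketch (helper name is ours): for the int8 case (bits == 8, so
// ELTS_PER_REG == 4 and B_ROWS_PER_MMA == 16), the tile_read_row formula above
// reproduces the 16-row map documented before the function.
inline bool example_int8_row_permutation_map() {
  const int expected[16] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
  const int ELTS_PER_REG = 4;
  for (int tile_row = 0; tile_row < 16; ++tile_row) {
    const int tile_read_row =
        8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
    if (tile_read_row != expected[tile_row]) return false;
  }
  return true;
}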

template<int bits=4>
void permute_B_rows_for_mixed_gemm_int4(int8_t* permuted_quantized_tensor,
                                        const int8_t* quantized_tensor,
                                        const std::vector<size_t>& shape,
                                        const int64_t arch_version) {
  // We only want to run this step for weight only quant.
  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

  const int BITS_PER_ELT = bits;               // 4
  const int K = 16 / BITS_PER_ELT;             // 4
  const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;  // 2
  const int ELTS_PER_REG = 32 / BITS_PER_ELT;  // 8

  const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
  uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);

  int MMA_SHAPE_N = 8;
  int B_ROWS_PER_MMA = 8 * K;  // 32
  const int elts_in_int32 = 32 / BITS_PER_ELT;

  const int num_vec_cols = num_cols / elts_in_int32;

  // Row map for one 32-row tile: write row i reads source row tile_col_map[i].
  const std::vector<int> tile_col_map{
      0,  2,  16, 18,
      1,  3,  17, 19,
      4,  6,  20, 22,
      5,  7,  21, 23,
      8,  10, 24, 26,
      9,  11, 25, 27,
      12, 14, 28, 30,
      13, 15, 29, 31};

  // Earlier experiments with the map, kept for reference:
  // const std::vector<int> tile_col_map{
  //     /*  0 */ 0,  2,  16, 18,
  //     /*  4 */ 1,  3,  17, 19,
  //     /*  8 */ 4,  6,  20, 22,
  //     /* 12 */ 5,  7,  21, 23,
  //     /* 16 */ 8,  10, 24, 26,
  //     /* 20 */ 9,  11, 25, 27,
  //     /* 24 */ 12, 14, 28, 30,
  //     /* 28 */ 13, 15, 29, 31};
  // std::vector<int> tile_col_map(32);
  // for (int i = 0; i < 32; i++) {
  //   tile_col_map[i] = i;
  // }
  // // tile_col_map[1] = 4;
  // tile_col_map[0] = 0;
  // tile_col_map[4] = 1;
  // tile_col_map[1] = 2;
  // tile_col_map[5] = 3;
  // tile_col_map[8] = 4;
  // tile_col_map[12] = 5;
  // tile_col_map[9] = 6;
  // tile_col_map[13] = 7;
  // tile_col_map[16] = 8;
  // tile_col_map[20] = 9;
  // tile_col_map[17] = 10;
  // tile_col_map[21] = 11;
  // tile_col_map[24] = 12;
  // tile_col_map[28] = 13;
  // tile_col_map[25] = 14;
  // tile_col_map[29] = 15;
  // tile_col_map[4] = 1;
  // tile_col_map[4] = 1;
  // tile_col_map[4] = 2;

  // The code is written as below so it works for both int8 and packed int4.
  for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
    for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
      for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
        const int write_row = base_row + tile_row;
        // const int tile_read_row =
        //     8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
        // const int tile_read_row = std::distance(tile_col_map.begin(),
        //     std::find(tile_col_map.begin(), tile_col_map.end(), tile_row));
        const int tile_read_row = tile_col_map[tile_row];
        if (base_row == 0 && write_col == 0) {
          std::cout << " write_row:" << tile_row << " tile_read_row:" << tile_read_row << std::endl;
        }
        const int read_row = base_row + tile_read_row;
        const int read_col = write_col;

        const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
        const int64_t write_offset = int64_t(write_row) * num_vec_cols + write_col;

        output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
      }
    }
  }
}

void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
                                    const int8_t* quantized_tensor,
                                    const std::vector<size_t>& shape) {
  // We only want to run this step for weight only quant.
  std::cout << "### in interleave_column_major_tensor" << std::endl;
  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

  const size_t BITS_PER_ELT = 8;
  const size_t elts_in_int32 = 32 / BITS_PER_ELT;

  const size_t rows_per_tile = 64;
  std::cout << "running interleave_column_major_tensor" << std::endl;
  std::cout << "num_rows:" << num_rows << ","
            << "num_cols:" << num_cols << ","
            << "BITS_PER_ELT:" << BITS_PER_ELT << ","
            << "elts_in_int32:" << elts_in_int32 << ","
            << "rows_per_tile:" << rows_per_tile << std::endl;

  const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
  uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);

  const size_t num_vec_rows = num_rows / elts_in_int32;
  const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32;
  const size_t interleave = 2;
  std::cout << "num_vec_rows:" << num_vec_rows << ","
            << "vec_rows_per_tile:" << vec_rows_per_tile << ","
            << "interleave:" << interleave << std::endl;
  for (int read_col = 0; read_col < num_cols; ++read_col) {
    const size_t write_col = read_col / interleave;
    for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) {
      for (int vec_read_row = base_vec_row;
           vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
           ++vec_read_row) {
        const size_t vec_write_row = interleave * base_vec_row
                                     + vec_rows_per_tile * (read_col % interleave)
                                     + vec_read_row % vec_rows_per_tile;

        const size_t read_offset = size_t(read_col) * num_vec_rows + vec_read_row;
        const size_t write_offset = size_t(write_col) * num_vec_rows * interleave + vec_write_row;
        output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
      }
    }
  }
}
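
// Illustrative sketch (the helper and its buffers are ours): with 8-bit elements,
// rows_per_tile == 64 and interleave == 2, the loop above merges each pair of
// adjacent columns of a column-major matrix into one output column, alternating
// 16-word (64-row) tiles: tile 0 of column 2k, tile 0 of column 2k+1, tile 1 of
// column 2k, and so on. The check below uses a hypothetical 128x2 int8 matrix,
// i.e. 32 uint32 words (two tiles) per column.
inline bool example_interleave_column_major_int8() {
  const size_t num_rows = 128, num_cols = 2;
  const size_t num_vec_rows = num_rows / 4;  // 32 uint32 words per column
  std::vector<uint32_t> src(num_cols * num_vec_rows), dst(num_cols * num_vec_rows);
  for (size_t ii = 0; ii < src.size(); ++ii) src[ii] = uint32_t(ii);  // word id = col * 32 + vec_row

  interleave_column_major_tensor(reinterpret_cast<int8_t*>(dst.data()),
                                 reinterpret_cast<const int8_t*>(src.data()),
                                 {num_rows, num_cols});

  // Output blocks of 16 words alternate between column 0 and column 1, tile by tile.
  for (size_t ii = 0; ii < dst.size(); ++ii) {
    const size_t block = ii / 16, word = ii % 16;
    const uint32_t expected = uint32_t((block % 2) * 32 + (block / 2) * 16 + word);
    if (dst[ii] != expected) return false;
  }
  return true;
}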

void interleave_column_major_tensor_int4(int8_t* interleaved_quantized_tensor,
                                         const int8_t* quantized_tensor,
                                         const std::vector<size_t>& shape) {
  // We only want to run this step for weight only quant.
  std::cout << "### in interleave_column_major_tensor_int4" << std::endl;
  const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
  const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

  const size_t BITS_PER_ELT = 4;
  const size_t elts_in_int32 = 32 / BITS_PER_ELT;

  const size_t rows_per_tile = 64;
  std::cout << "running interleave_column_major_tensor_int4" << std::endl;
  std::cout << "num_rows:" << num_rows << ","
            << "num_cols:" << num_cols << ","
            << "BITS_PER_ELT:" << BITS_PER_ELT << ","
            << "elts_in_int32:" << elts_in_int32 << ","
            << "rows_per_tile:" << rows_per_tile << std::endl;

  const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
  uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);

  const size_t num_vec_rows = num_rows / elts_in_int32;
  const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32;
  const size_t interleave = 4;
  std::cout << "num_vec_rows:" << num_vec_rows << ","
            << "vec_rows_per_tile:" << vec_rows_per_tile << ","
            << "interleave:" << interleave << std::endl;
  for (int read_col = 0; read_col < num_cols; ++read_col) {
    const size_t write_col = read_col / interleave;
    for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) {
      for (int vec_read_row = base_vec_row;
           vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
           ++vec_read_row) {
        const size_t vec_write_row = interleave * base_vec_row
                                     + vec_rows_per_tile * (read_col % interleave)
                                     + vec_read_row % vec_rows_per_tile;

        const size_t read_offset = size_t(read_col) * num_vec_rows + vec_read_row;
        const size_t write_offset = size_t(write_col) * num_vec_rows * interleave + vec_write_row;
        // std::cout << "read_offset:" << read_offset
        //           << ",write_offset:" << write_offset << std::endl;
        output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
      }
    }
  }
}