Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-12 20:11:20 +08:00)
【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)
* support MLA chunk_size auto search & cuda_graph
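The chunk-size auto search picks, for each decode step, the KV chunk size whose total tile count best fills the last wave of SMs; the diff below adds a device kernel (search_chunk_size_for_mla) that performs this search and a second kernel (split_block_for_mla) that expands the result into per-tile (batch_id, tile_id) tables sized for CUDA Graph capture. The host-side sketch below is an editor's illustration of the search heuristic only; the function name, toy inputs, and SM count are assumptions, not code from this commit.

// Editor's sketch (not part of the commit): a host-side reference of the
// chunk-size search heuristic implemented by search_chunk_size_for_mla.
// (The explicit set_chunk_size override in the kernel is omitted here.)
#include <cstdio>
#include <vector>

// For each candidate chunk size (block_size << i), count how many KV chunks
// (grid blocks) the decode-only requests would produce, then keep the
// candidate whose last wave leaves the most blocks resident on the SMs.
static int search_chunk_size_reference(const std::vector<int> &seq_lens_q,
                                       const std::vector<int> &seq_lens_encoder,
                                       const std::vector<int> &seq_lens_decoder,
                                       int block_size,   // 64 in the diff
                                       int config_size,  // 12 in the diff
                                       int sm_count) {
  int best_id = 0;
  int best_last_wave = 0;
  for (int conf_id = 0; conf_id < config_size; ++conf_id) {
    const int chunk_size = block_size << conf_id;
    int gridx = 0;
    for (size_t bid = 0; bid < seq_lens_q.size(); ++bid) {
      const int seq_len = seq_lens_q[bid];
      const int kv_len = seq_lens_decoder[bid] + seq_len;
      if (seq_len == 0 || seq_lens_encoder[bid] > 0) continue;  // skip prefill / idle slots
      gridx += (kv_len + chunk_size - 1) / chunk_size;          // ceil_div
    }
    const int last_wave = gridx % sm_count;
    if (conf_id > 0 && last_wave >= best_last_wave) {  // kernel scans i = 1..config_size-1
      best_id = conf_id;
      best_last_wave = last_wave;
    }
  }
  return block_size << best_id;
}

int main() {
  // Assumed toy batch: two decode requests with ~8K and ~3K tokens of KV cache.
  std::vector<int> q = {1, 1}, enc = {0, 0}, dec = {8192, 3072};
  printf("chosen chunk_size = %d\n",
         search_chunk_size_reference(q, enc, dec, /*block_size=*/64,
                                     /*config_size=*/12, /*sm_count=*/108));
  return 0;
}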
@@ -11,10 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
__global__ void
@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
      max_len_tensor.data<int>(), batch_size);
}

template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
    const int *__restrict__ seq_lens_q,
    const int *__restrict__ seq_lens_encoder,
    const int *__restrict__ seq_lens_decoder,
    int *__restrict__ num_blocks_x,
    int *__restrict__ res_chunk_size,
    const int bsz,
    const int set_chunk_size,
    const int block_size,
    const int sm_cout) {
  const uint32_t conf_id = threadIdx.x;
  int gridx = 0;
  if (set_chunk_size > 0 && conf_id == 0) {
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
      if (seq_len == 0 || seq_len_encoder > 0) continue;

      int loop_times;
      loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
      gridx += loop_times;
    }
    *num_blocks_x = gridx;
    *res_chunk_size = set_chunk_size;
  } else if (conf_id < config_size) {
    __shared__ int gridx_shared[config_size];
    // chunk_size is a multiple of 64
    const int chunk_size = block_size << conf_id;
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
      if (seq_len == 0 || seq_len_encoder > 0) continue;

      int loop_times;
      loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
      gridx += loop_times;
    }
    gridx_shared[conf_id] = gridx;
    __syncthreads();
    if (threadIdx.x == 0) {
      uint32_t res_id = 0;
      uint32_t max_last_wave_block = 0;
      for (uint32_t i = 1; i < config_size; i++) {
        uint32_t last_wave_block = gridx_shared[i] % sm_cout;
        if (last_wave_block >= max_last_wave_block) {
          res_id = i;
          max_last_wave_block = last_wave_block;
        }
      }
      *num_blocks_x = gridx_shared[res_id];
      *res_chunk_size = block_size << res_id;
    }
  }
}
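
// Editor's illustration (not part of the commit), assuming an SM count of 108:
// if the candidate chunk sizes produce total grids of {400, 200, 100, 50, ...}
// blocks, the last-wave sizes are 400 % 108 = 76, 200 % 108 = 92,
// 100 % 108 = 100 and 50 % 108 = 50, so the kernel above keeps the candidate
// whose grid is 100 blocks and reports its chunk size in res_chunk_size.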

__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
                                    const int *__restrict__ seq_lens_encoder,
                                    const int *__restrict__ seq_lens_decoder,
                                    int *__restrict__ batch_ids,
                                    int *__restrict__ tile_ids_per_batch,
                                    const int bsz,
                                    const int chunk_size) {
  if (threadIdx.x == 0) {
    int index = 0;
    for (uint32_t bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;

      if (seq_len == 0) continue;

      int loop_times;
      loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
      if (seq_len_encoder > 0) {
        loop_times = 0;
      }
      for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
        batch_ids[index] = bid;
        tile_ids_per_batch[index++] = tile_id;
      }
    }
  }
}
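
// Editor's illustration (not part of the commit): with an assumed batch of
// seq_lens_this_time = {1, 1, 0}, seq_lens_encoder = {0, 0, 0},
// seq_lens_decoder = {100, 300, 0} and chunk_size = 128, the kernel above
// emits ceil(101/128) = 1 tile for request 0 and ceil(301/128) = 3 tiles for
// request 1 (request 2 is idle), i.e.
//   batch_ids          = {0, 1, 1, 1}
//   tile_ids_per_batch = {0, 0, 1, 2}
// which lets the MLA decode kernel map each block to a (request, KV chunk) pair.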

__global__ void split_q_block(const int *__restrict__ seq_lens_q,
                              const int *__restrict__ seq_lens_encoder,
                              int *__restrict__ batch_ids,
@@ -197,7 +285,9 @@ void GetBlockShapeAndSplitKVBlock(
    const paddle::Tensor &seq_lens_this_time,
    paddle::Tensor &decoder_batch_ids,           // Inplace
    paddle::Tensor &decoder_tile_ids_per_batch,  // Inplace
    paddle::Tensor &decoder_num_blocks_x_cpu,    // Inplace, Pinned Memory
    paddle::Tensor &decoder_num_blocks_cpu,      // Inplace, Pinned Memory
    paddle::Tensor &decoder_num_blocks_device,   // Inplace
    paddle::Tensor &decoder_chunk_size_device,   // Inplace
    paddle::Tensor &max_len_tensor_cpu,          // Inplace, CPU
    paddle::Tensor &encoder_batch_ids,           // Inplace
    paddle::Tensor &encoder_tile_ids_per_batch,  // Inplace
@@ -230,8 +320,6 @@ void GetBlockShapeAndSplitKVBlock(
  int max_system_len = max_len_cpu_ptr[6];
  int max_just_dec_len_without_system = max_len_cpu_ptr[7];

  auto max_len_kv =
      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
  get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
@@ -241,6 +329,106 @@ void GetBlockShapeAndSplitKVBlock(

  max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);

  // decoder
  if (max_dec_len_this_time > 0) {
    const bool mla_use_tensorcore = GetMlaUseTensorcore();
    if (mla_use_tensorcore && group_size <= 64) {
      const int set_chunk_size = get_mla_dec_chunk_size(bsz);

      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
          decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
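      // Editor's note (not part of the commit): cudaMemsetAsync writes the value
      // 64 into each *byte* of the int, so the device word becomes 0x40404040
      // rather than 64. On this path the value is presumably only a placeholder
      // that search_chunk_size_for_mla overwrites with the selected chunk size.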

      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
          decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

      int device;
      cudaGetDevice(&device);
      int sm_cout;
      cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
      constexpr int config_size =
          12;  // search space for chunk size: [64, 128, 256, ..., 131072]

      search_chunk_size_for_mla<config_size>
          <<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
                                 seq_lens_encoder.data<int>(),
                                 seq_lens_decoder.data<int>(),
                                 decoder_num_blocks_device.data<int>(),
                                 decoder_chunk_size_device.data<int>(),
                                 bsz,
                                 set_chunk_size,
                                 block_size,
                                 sm_cout);

      decoder_num_blocks_cpu.copy_(
          decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
      auto decoder_chunk_size_cpu =
          decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
      const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];

      // NOTE: (changwenbin) When auto_chunk is used, decode_max_tile_size must
      // account for the worst case; the * 1024 factor is enough to cover 128K.
      // const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;

      const uint32_t decoder_max_tile_size_per_bs_q =
          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
      const uint32_t decoder_batch_shape =
          bsz * 1024 * decoder_max_tile_size_per_bs_q;
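      // Editor's note (illustration, not in the commit): 1024 tiles per request
      // at a 128-token chunk size gives 1024 * 128 = 131072 tokens, i.e. the
      // 128K context the NOTE above refers to; larger chunks need fewer tiles.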

      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(decoder_batch_ids.data<int>(),
                          0,
                          decoder_batch_shape * sizeof(int32_t),
                          stream));
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
                          0,
                          decoder_batch_shape * sizeof(int32_t),
                          stream));

      split_block_for_mla<<<1, 32, 0, stream>>>(
          seq_lens_this_time.data<int>(),
          seq_lens_encoder.data<int>(),
          seq_lens_decoder.data<int>(),
          decoder_batch_ids.data<int>(),
          decoder_tile_ids_per_batch.data<int>(),
          bsz,
          chunk_size);

    } else {
      // Note: (changwenbin) To adapt to CUDA Graph, the buffers are sized for
      // the maximum case here.
      const uint32_t decoder_max_tile_size_per_bs_q =
          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
      const uint32_t decoder_batch_shape =
          bsz * 1024 * decoder_max_tile_size_per_bs_q;
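      // Editor's note (not part of the commit): under CUDA Graph capture the
      // launch configuration and buffer addresses are fixed at capture time, so
      // the helper buffers are allocated and cleared for the worst case up front.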

      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));

      split_q_block<<<1, 32, 0, stream>>>(
          seq_lens_this_time.data<int>(),
          seq_lens_encoder.data<int>(),
          decoder_batch_ids.data<int>(),
          decoder_tile_ids_per_batch.data<int>(),
          decoder_num_blocks_device.data<int>(),
          bsz,
          decoder_block_shape_q,
          group_size);

      decoder_num_blocks_cpu.copy_(
          decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
          decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
    }
  } else {
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
        decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
    decoder_num_blocks_cpu.copy_(
        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
  }

  // encoder
  if (max_enc_len_this_time > 0) {
    const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
    const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
@@ -272,28 +460,6 @@ void GetBlockShapeAndSplitKVBlock(
    encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
  }

  if (max_just_dec_len_this_time > 0) {
    // Clear buffer
    const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
    const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));

    auto decoder_num_blocks_x =
        GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
    split_q_block<<<1, 32, 0, stream>>>(
        seq_lens_this_time.data<int>(),
        seq_lens_encoder.data<int>(),
        decoder_batch_ids.data<int>(),
        decoder_tile_ids_per_batch.data<int>(),
        decoder_num_blocks_x.data<int>(),
        bsz,
        decoder_block_shape_q,
        group_size);
    decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
  }
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -303,7 +469,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
    "seq_lens_this_time",
    "decoder_batch_ids",
    "decoder_tile_ids_per_batch",
    "decoder_num_blocks_x_cpu",
    "decoder_num_blocks_cpu",
    "decoder_num_blocks_device",
    "decoder_chunk_size_device",
    "max_len_tensor_cpu",
    "encoder_batch_ids",
    "encoder_tile_ids_per_batch",