Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-01 12:22:53 +08:00
【Bug Fix】mla enables tensorcore by default (#4354)
* mla tensor-core kernel is enabled by default
@@ -331,8 +331,9 @@ void GetBlockShapeAndSplitKVBlock(
   // decoder
   if (max_dec_len_this_time > 0) {
-    const bool mla_use_tensorcore = true; //GetMlaUseTensorcore();
-    if (mla_use_tensorcore && group_size <= 64) {
+    const bool mla_backend = checkAttentionBackend();
+    if (mla_backend && group_size <= 64) {
+      const int set_chunk_size = get_mla_dec_chunk_size(bsz);
       PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
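This hunk replaces the hard-coded mla_use_tensorcore = true with an environment-driven check: the MLA tensor-core decode branch is now taken only when FD_ATTENTION_BACKEND is set to MLA_ATTN and group_size <= 64. A minimal standalone sketch of that selection logic, assuming the helper behaves like the checkAttentionBackend() added further down in this patch (use_mla_tensor_core_decode is a hypothetical wrapper name, not part of the patch):

#include <cstdlib>
#include <cstring>

// Mirrors the checkAttentionBackend() helper added later in this
// patch: true iff FD_ATTENTION_BACKEND=MLA_ATTN is set.
static bool mla_backend_selected() {
  const char *backend = std::getenv("FD_ATTENTION_BACKEND");
  return backend && std::strcmp(backend, "MLA_ATTN") == 0;
}

// Hypothetical wrapper for the branch condition in the hunk above:
// the MLA tensor-core decode path is taken only for group_size <= 64.
static bool use_mla_tensor_core_decode(int group_size) {
  return mla_backend_selected() && group_size <= 64;
}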
@@ -396,28 +397,40 @@ void GetBlockShapeAndSplitKVBlock(
           chunk_size);
     } else {
 
-      // Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
-      const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-      const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
+      // Note:(changwenbin)In order to adapt to cudagraph, the maximum value
+      // should be taken here
+      const uint32_t decoder_max_tile_size_per_bs_q =
+          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+      const uint32_t decoder_batch_shape =
+          bsz * 1024 * decoder_max_tile_size_per_bs_q;
 
-      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(decoder_batch_ids.data<int>(),
+                          0,
+                          decoder_batch_shape * sizeof(int32_t),
+                          stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
+                          0,
+                          decoder_batch_shape * sizeof(int32_t),
+                          stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+          decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
 
       split_q_block<<<1, 32, 0, stream>>>(
           seq_lens_this_time.data<int>(),
           seq_lens_encoder.data<int>(),
           decoder_batch_ids.data<int>(),
           decoder_tile_ids_per_batch.data<int>(),
           decoder_num_blocks_device.data<int>(),
           bsz,
           decoder_block_shape_q,
           group_size);
 
       decoder_num_blocks_cpu.copy_(
           decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
       PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
           decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
     }
   } else {
     PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
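The note in this hunk explains the sizing: to stay compatible with CUDA graph capture, the decoder scratch buffers are cleared for the maximum possible tile count rather than the tiles actually used this step. A worked example of that arithmetic, using hypothetical shapes (bsz = 8, decoder_step_token_num = 1, group_size = 8, decoder_block_shape_q = 16); these numbers are illustrative, not taken from the patch:

#include <cstdint>
#include <cstdio>

// Same rounding-up division the kernel code relies on.
constexpr uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // Hypothetical shapes for illustration.
  const uint32_t bsz = 8;
  const uint32_t decoder_step_token_num = 1;
  const uint32_t group_size = 8;
  const uint32_t decoder_block_shape_q = 16;

  const uint32_t decoder_max_tile_size_per_bs_q =
      div_up(decoder_step_token_num * group_size, decoder_block_shape_q);
  const uint32_t decoder_batch_shape =
      bsz * 1024 * decoder_max_tile_size_per_bs_q;

  // div_up(1 * 8, 16) = 1, so 8 * 1024 * 1 = 8192 int32 slots are cleared
  // regardless of how many tiles this particular step uses.
  std::printf("tiles per bs_q: %u, batch shape: %u\n",
              decoder_max_tile_size_per_bs_q, decoder_batch_shape);
  return 0;
}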
@@ -27,6 +27,8 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 #include <unistd.h>
+#include <cstdlib>
+#include <cstring>
 
 #ifdef PADDLE_WITH_HIP
 #include <hip/hip_bfloat16.h>
@@ -604,6 +606,18 @@ inline bool GetMlaUseTensorcore() {
   return mla_use_tensorcore;
 }
 
+inline const char *getEnvVar(const char *varName) {
+  return std::getenv(varName);
+}
+
+inline bool checkAttentionBackend() {
+  const char *backend = getEnvVar("FD_ATTENTION_BACKEND");
+  if (backend && std::strcmp(backend, "MLA_ATTN") == 0) {
+    return true;
+  }
+  return false;
+}
+
 __device__ __forceinline__ float warpReduceMax(float value) {
   value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
   value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
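The trailing context above shows only the first two shuffle steps of warpReduceMax. For reference, the standard XOR-butterfly warp max reduction that those lines begin looks like the sketch below; this is the generic pattern, not necessarily the exact remainder of the function in this file:

__device__ __forceinline__ float warp_reduce_max_sketch(float value) {
  // XOR-shuffle butterfly: offsets 16, 8, 4, 2, 1 leave every lane
  // of the 32-lane warp holding the warp-wide maximum after five steps.
  for (int offset = 16; offset > 0; offset >>= 1) {
    value = fmaxf(value, __shfl_xor_sync(0xffffffffu, value, offset));
  }
  return value;
}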