diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
index d00e63875..b5e2baf87 100644
--- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -331,8 +331,9 @@ void GetBlockShapeAndSplitKVBlock(
 
   // decoder
   if (max_dec_len_this_time > 0) {
-    const bool mla_use_tensorcore = true; //GetMlaUseTensorcore();
-    if (mla_use_tensorcore && group_size <= 64) {
+
+    const bool mla_backend = checkAttentionBackend();
+    if (mla_backend && group_size <= 64) {
       const int set_chunk_size = get_mla_dec_chunk_size(bsz);
 
       PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
@@ -396,28 +397,40 @@ void GetBlockShapeAndSplitKVBlock(
           chunk_size);
 
     } else {
-    // Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
-    const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
+      // Note:(changwenbin)In order to adapt to cudagraph, the maximum value
+      // should be taken here
+      const uint32_t decoder_max_tile_size_per_bs_q =
+          div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+      const uint32_t decoder_batch_shape =
+          bsz * 1024 * decoder_max_tile_size_per_bs_q;
 
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(decoder_batch_ids.data<int>(),
+                          0,
+                          decoder_batch_shape * sizeof(int32_t),
+                          stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
+                          0,
+                          decoder_batch_shape * sizeof(int32_t),
+                          stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+          decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
 
-    split_q_block<<<1, 32, 0, stream>>>(
-        seq_lens_this_time.data<int>(),
-        seq_lens_encoder.data<int>(),
-        decoder_batch_ids.data<int>(),
-        decoder_tile_ids_per_batch.data<int>(),
-        decoder_num_blocks_device.data<int>(),
-        bsz,
-        decoder_block_shape_q,
-        group_size);
+      split_q_block<<<1, 32, 0, stream>>>(
+          seq_lens_this_time.data<int>(),
+          seq_lens_encoder.data<int>(),
+          decoder_batch_ids.data<int>(),
+          decoder_tile_ids_per_batch.data<int>(),
+          decoder_num_blocks_device.data<int>(),
+          bsz,
+          decoder_block_shape_q,
+          group_size);
 
-    decoder_num_blocks_cpu.copy_(
-        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
-    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
-        decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
+      decoder_num_blocks_cpu.copy_(
+          decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
+      PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
+          decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
     }
   } else {
     PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h
index eaa5e3f09..6f6554f03 100644
--- a/custom_ops/gpu_ops/helper.h
+++ b/custom_ops/gpu_ops/helper.h
@@ -27,6 +27,8 @@
 #include
 #include
 #include
+#include <cstdlib>
+#include <cstring>
 
 #ifdef PADDLE_WITH_HIP
 #include
@@ -604,6 +606,18 @@ inline bool GetMlaUseTensorcore() {
   return mla_use_tensorcore;
 }
 
+inline const char *getEnvVar(const char *varName) {
+  return std::getenv(varName);
+}
+
+inline bool checkAttentionBackend() {
+  const char *backend = getEnvVar("FD_ATTENTION_BACKEND");
+  if (backend && std::strcmp(backend, "MLA_ATTN") == 0) {
+    return true;
+  }
+  return false;
+}
+
 __device__ __forceinline__ float warpReduceMax(float value) {
   value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
   value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
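
For reference, a minimal standalone sketch of how the new FD_ATTENTION_BACKEND gate behaves. The main() driver, the helper name attention_backend_is_mla, and the setenv()/unsetenv() calls (POSIX) are illustrative only and not part of the patch; only the environment-variable name, the "MLA_ATTN" value, and the group_size <= 64 condition come from the diff above.

// Illustrative sketch only: exercises the same check that the new
// checkAttentionBackend() helper in helper.h performs.
// Assumes a POSIX system for setenv()/unsetenv().
#include <cstdio>
#include <cstdlib>
#include <cstring>

static bool attention_backend_is_mla() {
  const char *backend = std::getenv("FD_ATTENTION_BACKEND");
  return backend != nullptr && std::strcmp(backend, "MLA_ATTN") == 0;
}

int main() {
  // Unset: the generic split_q_block decode path is used.
  unsetenv("FD_ATTENTION_BACKEND");
  std::printf("unset    -> %d\n", attention_backend_is_mla());  // 0

  // MLA_ATTN: the MLA decode path is taken (when group_size <= 64).
  setenv("FD_ATTENTION_BACKEND", "MLA_ATTN", 1);
  std::printf("MLA_ATTN -> %d\n", attention_backend_is_mla());  // 1

  // Any other value keeps the generic path.
  setenv("FD_ATTENTION_BACKEND", "OTHER", 1);
  std::printf("OTHER    -> %d\n", attention_backend_is_mla());  // 0
  return 0;
}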