From 3a6883ac1a92f431e390dd33ef3e0fc4c0023c86 Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Wed, 22 Oct 2025 17:59:50 +0800 Subject: [PATCH] c++ code format (#4527) --- .clang-format | 2 +- .pre-commit-config.yaml | 13 + custom_ops/cpu_ops/avx_weight_only_fake.cc | 20 +- custom_ops/cpu_ops/get_padding_offset.cc | 110 +- custom_ops/cpu_ops/rebuild_padding.cc | 350 ++-- custom_ops/cpu_ops/set_value_by_flags.cc | 50 +- custom_ops/cpu_ops/simd_sort.cc | 58 +- custom_ops/cpu_ops/simd_sort_fake.cc | 22 +- .../cpu_ops/stop_generation_multi_ends.cc | 66 +- .../cpu_ops/token_penalty_multi_scores.cc | 181 +- custom_ops/cpu_ops/update_inputs.cc | 82 +- custom_ops/cpu_ops/xft_all_layer_fake.cc | 8 +- custom_ops/cpu_ops/xft_greedy_search_fake.cc | 16 +- custom_ops/iluvatar_ops/fused_moe_imp_op.h | 2 +- custom_ops/iluvatar_ops/fused_moe_op.h | 141 +- custom_ops/iluvatar_ops/mixed_fused_attn.cu | 522 ++--- custom_ops/iluvatar_ops/moe_dispatch.cu | 109 +- custom_ops/iluvatar_ops/moe_reduce.cu | 137 +- custom_ops/iluvatar_ops/paged_attn.cu | 583 +++--- custom_ops/iluvatar_ops/prefill_fused_attn.cu | 590 +++--- .../iluvatar_ops/runtime/iluvatar_context.cc | 1 - .../iluvatar_ops/runtime/iluvatar_context.h | 30 +- custom_ops/iluvatar_ops/w8a16_group_gemm.cu | 298 +-- custom_ops/metax_ops/fused_moe.cu | 73 +- custom_ops/metax_ops/fused_moe_imp_op.h | 2 +- custom_ops/metax_ops/fused_moe_op.h | 21 +- custom_ops/metax_ops/mc_fused_moe_helper.h | 696 ++++--- custom_ops/metax_ops/moe_dispatch.cu | 19 +- custom_ops/metax_ops/moe_ffn.cu | 126 +- custom_ops/metax_ops/moe_reduce.cu | 6 +- custom_ops/xpu_ops/src/ops/adjust_batch.cc | 192 +- custom_ops/xpu_ops/src/ops/block_attn.cc | 124 +- .../device/get_context_gm_max_mem_demand.cc | 33 +- .../src/ops/device/get_free_global_memory.cc | 38 +- .../src/ops/device/get_total_global_memory.cc | 38 +- .../src/ops/device/get_used_global_memory.cc | 38 +- .../xpu_ops/src/ops/gather_next_token.cc | 115 +- .../xpu_ops/src/ops/get_img_boundaries.cc | 69 +- custom_ops/xpu_ops/src/ops/get_output.cc | 119 +- .../xpu_ops/src/ops/get_padding_offset.cc | 87 +- .../src/ops/get_token_penalty_multi_scores.cc | 131 +- custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc | 2 +- custom_ops/xpu_ops/src/ops/moe_layer.cc | 418 ++-- .../src/ops/pybind/alloc_cache_pinned.cc | 2 +- custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 14 +- .../xpu_ops/src/ops/recover_decode_task.cc | 51 +- .../xpu_ops/src/ops/remote_cache_kv_ipc.cc | 8 +- .../xpu_ops/src/ops/save_with_output_msg.cc | 178 +- .../src/ops/set_value_by_flags_and_idx.cc | 45 +- .../xpu_ops/src/ops/share_external_data.cc | 5 +- custom_ops/xpu_ops/src/ops/step.cc | 177 +- .../src/ops/stop_generation_multi_ends.cc | 41 +- .../src/ops/text_image_gather_scatter.cc | 80 +- .../xpu_ops/src/ops/text_image_index_out.cc | 45 +- custom_ops/xpu_ops/src/ops/update_inputs.cc | 63 +- .../xpu_ops/src/ops/update_inputs_v1.cc | 95 +- custom_ops/xpu_ops/src/ops/utility/debug.h | 0 custom_ops/xpu_ops/src/ops/utility/helper.h | 63 +- .../xpu_ops/src/ops/weight_quantize_xpu.cc | 184 +- custom_ops/xpu_ops/src/ops/xpu_multiprocess.h | 62 +- .../xpu_ops/src/plugin/include/xpu/plugin.h | 262 ++- .../kernel/kunlun3cpp/get_padding_offset.xpu | 4 +- .../mtp_kernel/rebuild_append_padding.xpu | 122 +- .../speculate_update_repeat_times.xpu | 8 +- ...speculate_update_value_by_repeat_times.xpu | 16 +- .../kernel/kunlun3cpp/recover_decode_task.xpu | 38 +- .../src/kernel/kunlun3cpp/remove_padding.xpu | 4 +- .../kunlun3cpp/set_stop_value_multi_ends.xpu | 5 +- .../kunlun3cpp/text_image_gather_scatter.xpu | 326 +-- .../kunlun3cpp/text_image_index_out.xpu | 137 +- .../src/kernel/kunlun3cpp/update_inputs.xpu | 3 +- .../kernel/kunlun3cpp/update_inputs_v1.xpu | 169 +- .../update_value_by_repeat_times.xpu | 69 +- .../plugin/src/wrapper/eb_adjust_batch.cpp | 118 +- .../src/wrapper/eb_gather_next_token.cpp | 130 +- .../src/wrapper/free_and_dispatch_block.cpp | 439 +++-- .../plugin/src/wrapper/get_padding_offset.cpp | 238 ++- .../mtp_wrapper/draft_model_postprocess.cpp | 39 +- .../wrapper/nn_set_stop_value_multi_ends.cpp | 220 ++- .../wrapper/nn_set_value_by_flags_and_idx.cpp | 233 ++- .../wrapper/nn_token_penalty_multi_scores.cpp | 538 +++-- .../src/wrapper/quant2d_per_channel.cpp | 128 +- .../src/plugin/src/wrapper/recover_block.cpp | 334 ++-- .../src/wrapper/recover_decode_task.cpp | 158 +- .../src/wrapper/text_image_gather_scatter.cpp | 154 +- .../src/wrapper/text_image_index_out.cpp | 68 +- .../src/plugin/src/wrapper/update_inputs.cpp | 216 +- .../plugin/src/wrapper/update_inputs_v1.cpp | 244 +-- .../include/kvcache_connection.h | 239 +-- .../kvcache_transfer/include/kvcache_rdma.h | 185 +- .../kvcache_transfer/include/log.h | 161 +- .../kvcache_transfer/include/util.h | 508 ++--- .../src/kvcache_connection.cpp | 1635 ++++++++-------- .../kvcache_transfer/src/kvcache_rdma.cpp | 1744 +++++++++-------- .../kvcache_transfer/src/log.cpp | 371 ++-- .../kvcache_transfer/src/pybind.cpp | 23 +- tools/codestyle/pre_commit.sh | 5 + 97 files changed, 8760 insertions(+), 7382 deletions(-) mode change 100755 => 100644 custom_ops/xpu_ops/src/ops/utility/debug.h diff --git a/.clang-format b/.clang-format index 3bb927623..a4de8e7be 100644 --- a/.clang-format +++ b/.clang-format @@ -16,7 +16,7 @@ --- Language: Cpp BasedOnStyle: Google -IndentWidth: 4 +IndentWidth: 2 TabWidth: 2 ContinuationIndentWidth: 4 AccessModifierOffset: -1 # The private/protected/public has no indent in class diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c0fec84a..0e72fd69f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,7 @@ +exclude: | + (?x)^( + dockerfiles/.+ + )$ default_install_hook_types: - pre-commit - commit-msg @@ -27,6 +31,15 @@ repos: hooks: - id: ruff args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml] +# For C++ files +- repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat. + entry: clang-format -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|xpu|kps)$ # # 拼写检查 # - repo: https://github.com/codespell-project/codespell # rev: v2.4.1 diff --git a/custom_ops/cpu_ops/avx_weight_only_fake.cc b/custom_ops/cpu_ops/avx_weight_only_fake.cc index 2150669af..d117e6606 100644 --- a/custom_ops/cpu_ops/avx_weight_only_fake.cc +++ b/custom_ops/cpu_ops/avx_weight_only_fake.cc @@ -19,28 +19,28 @@ std::vector InvokeAvxWeightOnly(const paddle::Tensor &x, const paddle::Tensor &w_bias, const std::string &alog, bool trans) { - auto out_shape = x.shape(); - out_shape[out_shape.size() - 1] = weight.shape()[1]; - auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace()); - return {out}; + auto out_shape = x.shape(); + out_shape[out_shape.size() - 1] = weight.shape()[1]; + auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace()); + return {out}; } std::vector> AvxWeightOnlyInferShape( std::vector x_shape, std::vector weigh_shape, std::vector weigh_bias_shape) { - int m = 1; - for (int i = 0; i < x_shape.size() - 1; i++) { - m = m * x_shape[i]; - } - return {std::vector{m, weigh_shape[1]}}; + int m = 1; + for (int i = 0; i < x_shape.size() - 1; i++) { + m = m * x_shape[i]; + } + return {std::vector{m, weigh_shape[1]}}; } std::vector AvxWeightOnlyInferDtype( paddle::DataType x_dtype, paddle::DataType weight_dtype, paddle::DataType weight_bias_dtype) { - return {x_dtype}; + return {x_dtype}; } PD_BUILD_STATIC_OP(avx_weight_only) diff --git a/custom_ops/cpu_ops/get_padding_offset.cc b/custom_ops/cpu_ops/get_padding_offset.cc index 02ee71a26..50af5a295 100644 --- a/custom_ops/cpu_ops/get_padding_offset.cc +++ b/custom_ops/cpu_ops/get_padding_offset.cc @@ -20,13 +20,13 @@ void remove_padding(int64_t *output_data, const int *cum_offsets, const int sequence_length, const int bsz) { - for (int bi = 0; bi < bsz; ++bi) { - for (int i = 0; i < seq_lens[bi]; ++i) { - const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i; - const int src_seq_id = bi * sequence_length + i; - output_data[tgt_seq_id] = input_data[src_seq_id]; - } + for (int bi = 0; bi < bsz; ++bi) { + for (int i = 0; i < seq_lens[bi]; ++i) { + const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i; + const int src_seq_id = bi * sequence_length + i; + output_data[tgt_seq_id] = input_data[src_seq_id]; } + } } void get_padding_offset_kernel(int *padding_offset, @@ -37,56 +37,53 @@ void get_padding_offset_kernel(int *padding_offset, const int *seq_lens, const int max_seq_len, const int bsz) { - for (int bi = 0; bi < bsz; ++bi) { - int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; - auto seq_len_now = seq_lens[bi]; - for (int i = 0; i < seq_len_now; ++i) { - padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; - } - cum_offsets_out[bi] = cum_offset; - int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; - cu_seqlens_q[bi + 1] = cum_seq_len; - cu_seqlens_k[bi + 1] = cum_seq_len; + for (int bi = 0; bi < bsz; ++bi) { + int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; + auto seq_len_now = seq_lens[bi]; + for (int i = 0; i < seq_len_now; ++i) { + padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; } + cum_offsets_out[bi] = cum_offset; + int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; + cu_seqlens_q[bi + 1] = cum_seq_len; + cu_seqlens_k[bi + 1] = cum_seq_len; + } } std::vector GetPaddingOffset(const paddle::Tensor &input_ids, const paddle::Tensor &cum_offsets, const paddle::Tensor &token_num, const paddle::Tensor &seq_len) { - std::vector input_ids_shape = input_ids.shape(); - const int bsz = seq_len.shape()[0]; - const int seq_length = input_ids_shape[1]; - auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false); - auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int seq_length = input_ids_shape[1]; + auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false); + auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); - const int token_num_data = cpu_token_num.data()[0]; - auto x_remove_padding = paddle::empty( - {token_num_data}, paddle::DataType::INT64, input_ids.place()); - auto padding_offset = paddle::empty( - {token_num_data}, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_q = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_k = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - get_padding_offset_kernel(padding_offset.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - cum_offsets.data(), - seq_len.data(), - seq_length, - bsz); - remove_padding(x_remove_padding.data(), - input_ids.data(), - seq_len.data(), - cum_offsets_out.data(), - seq_length, - bsz); - return {x_remove_padding, - padding_offset, - cu_seqlens_q, - cu_seqlens_k}; + const int token_num_data = cpu_token_num.data()[0]; + auto x_remove_padding = paddle::empty( + {token_num_data}, paddle::DataType::INT64, input_ids.place()); + auto padding_offset = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + get_padding_offset_kernel(padding_offset.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + remove_padding(x_remove_padding.data(), + input_ids.data(), + seq_len.data(), + cum_offsets_out.data(), + seq_length, + bsz); + return {x_remove_padding, padding_offset, cu_seqlens_q, cu_seqlens_k}; } std::vector> GetPaddingOffsetInferShape( @@ -94,9 +91,9 @@ std::vector> GetPaddingOffsetInferShape( const std::vector &cum_offsets_shape, const std::vector &token_num_shape, const std::vector &seq_len_shape) { - int64_t bsz = seq_len_shape[0]; - int64_t seq_len = input_ids_shape[1]; - return {{-1}, {-1}, {bsz + 1}, {bsz + 1}}; + int64_t bsz = seq_len_shape[0]; + int64_t seq_len = input_ids_shape[1]; + return {{-1}, {-1}, {bsz + 1}, {bsz + 1}}; } std::vector GetPaddingOffsetInferDtype( @@ -104,18 +101,13 @@ std::vector GetPaddingOffsetInferDtype( const paddle::DataType &cum_offsets_dtype, const paddle::DataType &token_num_dtype, const paddle::DataType &seq_len_dtype) { - return {input_ids_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype}; + return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype}; } PD_BUILD_STATIC_OP(get_padding_offset_cpu) .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) - .Outputs({"x_remove_padding", - "padding_offset", - "cu_seqlens_q", - "cu_seqlens_k"}) + .Outputs( + {"x_remove_padding", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(GetPaddingOffset)) .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype)); diff --git a/custom_ops/cpu_ops/rebuild_padding.cc b/custom_ops/cpu_ops/rebuild_padding.cc index 2dfc9f17e..9e4627dfb 100644 --- a/custom_ops/cpu_ops/rebuild_padding.cc +++ b/custom_ops/cpu_ops/rebuild_padding.cc @@ -19,7 +19,6 @@ #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif - template void RebuildPaddingCPUImpl(T *output_data, const T *input_data, @@ -30,27 +29,27 @@ void RebuildPaddingCPUImpl(T *output_data, int max_input_length, int dim_embed, const int elem_nums) { - for (int i = 0; i < elem_nums; ++i) { - const int bi = i / dim_embed; - const int bias_idx = i % dim_embed; - int seq_id = 0; + for (int i = 0; i < elem_nums; ++i) { + const int bi = i / dim_embed; + const int bias_idx = i % dim_embed; + int seq_id = 0; - if (seq_len_this_time_data[bi] == 0) { - continue; - } - if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { - continue; - } - - if (seq_lens_encoder_data[bi] > 0) { - seq_id = seq_lens_encoder_data[bi] - 1; - } - - const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id; - const int src_offset = ori_token_idx * dim_embed + bias_idx; - - output_data[i] = input_data[src_offset]; + if (seq_len_this_time_data[bi] == 0) { + continue; } + if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { + continue; + } + + if (seq_lens_encoder_data[bi] > 0) { + seq_id = seq_lens_encoder_data[bi] - 1; + } + + const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id; + const int src_offset = ori_token_idx * dim_embed + bias_idx; + + output_data[i] = input_data[src_offset]; + } } template @@ -64,27 +63,25 @@ void RebuildAppendPaddingCPUImpl(T *output_data, const int max_input_length, const int dim_embed, const int64_t output_elem_nums) { - for (int i = 0; i < output_elem_nums; ++i) { - int out_token_id = i / dim_embed; - int ori_token_id = - out_token_id + output_padding_offset_data[out_token_id]; - int bi = ori_token_id / max_input_length; - if (seq_len_this_time_data[bi] == 0 || - (seq_lens_decoder_data[bi] == 0 && - seq_lens_encoder_data[bi] == 0)) { - continue; - } - int seq_id = 0; - - if (seq_lens_encoder_data[bi] > 0) { - seq_id = seq_lens_encoder_data[bi] - 1; - } - int input_token_id = cu_seqlens_q_data[bi] + seq_id; - int bias_idx = i % dim_embed; - int src_offset = input_token_id * dim_embed + bias_idx; - - output_data[i] = input_data[src_offset]; + for (int i = 0; i < output_elem_nums; ++i) { + int out_token_id = i / dim_embed; + int ori_token_id = out_token_id + output_padding_offset_data[out_token_id]; + int bi = ori_token_id / max_input_length; + if (seq_len_this_time_data[bi] == 0 || + (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0)) { + continue; } + int seq_id = 0; + + if (seq_lens_encoder_data[bi] > 0) { + seq_id = seq_lens_encoder_data[bi] - 1; + } + int input_token_id = cu_seqlens_q_data[bi] + seq_id; + int bias_idx = i % dim_embed; + int src_offset = input_token_id * dim_embed + bias_idx; + + output_data[i] = input_data[src_offset]; + } } std::vector RebuildPaddingCPU( @@ -95,140 +92,139 @@ std::vector RebuildPaddingCPU( const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, int max_input_length) { - auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); - auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true); - auto seq_len_this_time_cpu = - seq_len_this_time.copy_to(paddle::CPUPlace(), true); - auto seq_lens_decoder_cpu = - seq_lens_decoder.copy_to(paddle::CPUPlace(), true); - auto seq_lens_encoder_cpu = - seq_lens_encoder.copy_to(paddle::CPUPlace(), true); - paddle::optional output_padding_offset_cpu; - if (output_padding_offset) { - output_padding_offset_cpu = - output_padding_offset->copy_to(paddle::CPUPlace(), true); + auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); + auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true); + auto seq_len_this_time_cpu = + seq_len_this_time.copy_to(paddle::CPUPlace(), true); + auto seq_lens_decoder_cpu = + seq_lens_decoder.copy_to(paddle::CPUPlace(), true); + auto seq_lens_encoder_cpu = + seq_lens_encoder.copy_to(paddle::CPUPlace(), true); + paddle::optional output_padding_offset_cpu; + if (output_padding_offset) { + output_padding_offset_cpu = + output_padding_offset->copy_to(paddle::CPUPlace(), true); + } + + int token_num = tmp_out_cpu.shape()[0]; + int dim_embed = tmp_out_cpu.shape()[1]; + int bsz = cu_seqlens_q_cpu.shape()[0] - 1; + + paddle::Tensor out; + if (output_padding_offset_cpu) { + int need_delete_token_num = 0; + for (int i = 0; i < bsz; ++i) { + if (seq_lens_encoder_cpu.data()[i] > 0) { + need_delete_token_num += seq_lens_encoder_cpu.data()[i] - 1; + } } + int output_token_num = token_num - need_delete_token_num; + out = paddle::full({output_token_num, dim_embed}, + 0, + tmp_out_cpu.dtype(), + paddle::CPUPlace()); + } else { + out = paddle::full( + {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); + } - int token_num = tmp_out_cpu.shape()[0]; - int dim_embed = tmp_out_cpu.shape()[1]; - int bsz = cu_seqlens_q_cpu.shape()[0] - 1; + const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data(); + const int *seq_len_this_time_data = seq_len_this_time_cpu.data(); + const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data(); + const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data(); + int elem_nums = out.numel(); - paddle::Tensor out; - if (output_padding_offset_cpu) { - int need_delete_token_num = 0; - for (int i = 0; i < bsz; ++i) { - if (seq_lens_encoder_cpu.data()[i] > 0) { - need_delete_token_num += - seq_lens_encoder_cpu.data()[i] - 1; - } - } - int output_token_num = token_num - need_delete_token_num; - out = paddle::full({output_token_num, dim_embed}, - 0, - tmp_out_cpu.dtype(), - paddle::CPUPlace()); - } else { - out = paddle::full( - {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); + if (output_padding_offset_cpu) { + const int *output_padding_offset_data = + output_padding_offset_cpu->data(); + switch (tmp_out_cpu.dtype()) { + case paddle::DataType::FLOAT32: + RebuildAppendPaddingCPUImpl(out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::FLOAT16: + RebuildAppendPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::BFLOAT16: + RebuildAppendPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + default: + PD_THROW( + "Unsupported data type for rebuild_padding_cpu. " + "Only float32, float16, and bfloat16 are supported."); } - - const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data(); - const int *seq_len_this_time_data = seq_len_this_time_cpu.data(); - const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data(); - const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data(); - int elem_nums = out.numel(); - - if (output_padding_offset_cpu) { - const int *output_padding_offset_data = - output_padding_offset_cpu->data(); - switch (tmp_out_cpu.dtype()) { - case paddle::DataType::FLOAT32: - RebuildAppendPaddingCPUImpl(out.data(), - tmp_out_cpu.data(), - cu_seqlens_q_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - output_padding_offset_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::FLOAT16: - RebuildAppendPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cu_seqlens_q_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - output_padding_offset_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::BFLOAT16: - RebuildAppendPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cu_seqlens_q_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - output_padding_offset_data, - max_input_length, - dim_embed, - elem_nums); - break; - default: - PD_THROW( - "Unsupported data type for rebuild_padding_cpu. " - "Only float32, float16, and bfloat16 are supported."); - } - } else { - switch (tmp_out_cpu.dtype()) { - case paddle::DataType::FLOAT32: - RebuildPaddingCPUImpl(out.data(), - tmp_out_cpu.data(), - cu_seqlens_q_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::FLOAT16: - RebuildPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cu_seqlens_q_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::BFLOAT16: - RebuildPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cu_seqlens_q_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - max_input_length, - dim_embed, - elem_nums); - break; - default: - PD_THROW( - "Unsupported data type for rebuild_padding_cpu. " - "Only float32, float16, and bfloat16 are supported."); - } + } else { + switch (tmp_out_cpu.dtype()) { + case paddle::DataType::FLOAT32: + RebuildPaddingCPUImpl(out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::FLOAT16: + RebuildPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::BFLOAT16: + RebuildPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + default: + PD_THROW( + "Unsupported data type for rebuild_padding_cpu. " + "Only float32, float16, and bfloat16 are supported."); } - return {out}; + } + return {out}; } std::vector> RebuildPaddingInferShape( @@ -238,13 +234,13 @@ std::vector> RebuildPaddingInferShape( const std::vector &seq_lens_decoder_shape, const std::vector &seq_lens_encoder_shape, const paddle::optional> &output_padding_offset_shape) { - int64_t dim_embed = tmp_out_shape[1]; - if (output_padding_offset_shape) { - return {{-1, dim_embed}}; - } else { - int64_t bsz = cu_seqlens_q_shape[0] - 1; - return {{bsz, dim_embed}}; - } + int64_t dim_embed = tmp_out_shape[1]; + if (output_padding_offset_shape) { + return {{-1, dim_embed}}; + } else { + int64_t bsz = cu_seqlens_q_shape[0] - 1; + return {{bsz, dim_embed}}; + } } std::vector RebuildPaddingInferDtype( @@ -254,7 +250,7 @@ std::vector RebuildPaddingInferDtype( const paddle::DataType &seq_lens_decoder_dtype, const paddle::DataType &seq_lens_encoder_dtype, const paddle::optional &output_padding_offset_dtype) { - return {tmp_out_dtype}; + return {tmp_out_dtype}; } PD_BUILD_STATIC_OP(rebuild_padding_cpu) diff --git a/custom_ops/cpu_ops/set_value_by_flags.cc b/custom_ops/cpu_ops/set_value_by_flags.cc index c7e64f432..1266afa1e 100644 --- a/custom_ops/cpu_ops/set_value_by_flags.cc +++ b/custom_ops/cpu_ops/set_value_by_flags.cc @@ -15,27 +15,27 @@ #include "paddle/extension.h" void set_value_by_flags_and_idx(const bool *stop_flags, - int64_t *pre_ids_all, - const int64_t *input_ids, - const int *seq_lens_encoder, - const int *seq_lens_decoder, - const int64_t *step_idx, - int bs, - int length, - int length_input_ids) { - for (int bi = 0; bi < bs; bi++) { - if (!stop_flags[bi]) { - const int seq_len_dec = seq_lens_decoder[bi]; - const int seq_len_enc = seq_lens_encoder[bi]; - int64_t *pre_ids_all_now = pre_ids_all + bi * length; - const int64_t *input_ids_now = input_ids + bi * length_input_ids; - if (seq_len_dec == 0) { - pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1]; - } else { - pre_ids_all_now[step_idx[bi]] = input_ids_now[0]; - } - } + int64_t *pre_ids_all, + const int64_t *input_ids, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int length_input_ids) { + for (int bi = 0; bi < bs; bi++) { + if (!stop_flags[bi]) { + const int seq_len_dec = seq_lens_decoder[bi]; + const int seq_len_enc = seq_lens_encoder[bi]; + int64_t *pre_ids_all_now = pre_ids_all + bi * length; + const int64_t *input_ids_now = input_ids + bi * length_input_ids; + if (seq_len_dec == 0) { + pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1]; + } else { + pre_ids_all_now[step_idx[bi]] = input_ids_now[0]; + } } + } } void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, @@ -45,12 +45,12 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags) { - std::vector pre_ids_all_shape = pre_ids_all.shape(); - int bs = seq_lens_this_time.shape()[0]; - int length = pre_ids_all_shape[1]; - int length_input_ids = input_ids.shape()[1]; + std::vector pre_ids_all_shape = pre_ids_all.shape(); + int bs = seq_lens_this_time.shape()[0]; + int length = pre_ids_all_shape[1]; + int length_input_ids = input_ids.shape()[1]; - set_value_by_flags_and_idx(stop_flags.data(), + set_value_by_flags_and_idx(stop_flags.data(), const_cast(pre_ids_all.data()), input_ids.data(), seq_lens_encoder.data(), diff --git a/custom_ops/cpu_ops/simd_sort.cc b/custom_ops/cpu_ops/simd_sort.cc index 581ee4069..857875a41 100644 --- a/custom_ops/cpu_ops/simd_sort.cc +++ b/custom_ops/cpu_ops/simd_sort.cc @@ -21,45 +21,45 @@ void probs_sort(const float *probs, float *ProbsVals, int vocab_size, int bsz) { - float cursum = 0; - std::vector elementsIds(vocab_size); - std::vector elementsProbs(vocab_size); + float cursum = 0; + std::vector elementsIds(vocab_size); + std::vector elementsProbs(vocab_size); #pragma omp parallel for - for (int j = 0; j < vocab_size; j++) { - elementsIds[j] = j; - elementsProbs[j] = probs[j]; - } - x86simdsortStatic::keyvalue_qsort( - elementsProbs.data(), elementsIds.data(), vocab_size, false, true); + for (int j = 0; j < vocab_size; j++) { + elementsIds[j] = j; + elementsProbs[j] = probs[j]; + } + x86simdsortStatic::keyvalue_qsort( + elementsProbs.data(), elementsIds.data(), vocab_size, false, true); #pragma omp parallel for - for (int j = 0; j < vocab_size; ++j) { - ProbsVals[j] = elementsProbs[j]; - ProbsIds[j] = elementsIds[j]; - } + for (int j = 0; j < vocab_size; ++j) { + ProbsVals[j] = elementsProbs[j]; + ProbsIds[j] = elementsIds[j]; + } } std::vector SimdSort(const paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto sorted_indices = paddle::empty( - {bsz, vocab_size}, paddle::DataType::INT64, probs.place()); - auto sorted_probs = paddle::empty( - {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); - probs_sort(probs.data(), - const_cast(sorted_indices.data()), - const_cast(sorted_probs.data()), - vocab_size, - bsz); - return {sorted_indices, sorted_probs}; + const int bsz = probs.shape()[0]; + const int vocab_size = probs.shape()[1]; + auto sorted_indices = + paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place()); + auto sorted_probs = paddle::empty( + {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); + probs_sort(probs.data(), + const_cast(sorted_indices.data()), + const_cast(sorted_probs.data()), + vocab_size, + bsz); + return {sorted_indices, sorted_probs}; } std::vector> SimdSortInferShape( const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - int64_t vocab_size = probs_shape[1]; - return {{bsz, vocab_size}, {bsz, vocab_size}}; + int64_t bsz = probs_shape[0]; + int64_t vocab_size = probs_shape[1]; + return {{bsz, vocab_size}, {bsz, vocab_size}}; } std::vector SimdSortInferDtype( const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; + return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; } PD_BUILD_STATIC_OP(simd_sort) .Inputs({"probs"}) diff --git a/custom_ops/cpu_ops/simd_sort_fake.cc b/custom_ops/cpu_ops/simd_sort_fake.cc index 82cb1af1c..514ff1fa9 100644 --- a/custom_ops/cpu_ops/simd_sort_fake.cc +++ b/custom_ops/cpu_ops/simd_sort_fake.cc @@ -16,23 +16,23 @@ #include "paddle/extension.h" std::vector SimdSort(const paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto sorted_indices = paddle::empty( - {bsz, vocab_size}, paddle::DataType::INT64, probs.place()); - auto sorted_probs = paddle::empty( - {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); - return {sorted_indices, sorted_probs}; + const int bsz = probs.shape()[0]; + const int vocab_size = probs.shape()[1]; + auto sorted_indices = + paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place()); + auto sorted_probs = paddle::empty( + {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); + return {sorted_indices, sorted_probs}; } std::vector> SimdSortInferShape( const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - int64_t vocab_size = probs_shape[1]; - return {{bsz, vocab_size}, {bsz, vocab_size}}; + int64_t bsz = probs_shape[0]; + int64_t vocab_size = probs_shape[1]; + return {{bsz, vocab_size}, {bsz, vocab_size}}; } std::vector SimdSortInferDtype( const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; + return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; } PD_BUILD_STATIC_OP(simd_sort) .Inputs({"probs"}) diff --git a/custom_ops/cpu_ops/stop_generation_multi_ends.cc b/custom_ops/cpu_ops/stop_generation_multi_ends.cc index 37f1f40c2..cd4c9323a 100644 --- a/custom_ops/cpu_ops/stop_generation_multi_ends.cc +++ b/custom_ops/cpu_ops/stop_generation_multi_ends.cc @@ -23,13 +23,13 @@ #endif bool is_in_end(const int64_t id, const int64_t *end_ids, int length) { - bool flag = false; - for (int i = 0; i < length; i++) { - if (id == end_ids[i]) { - return true; - } + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; } - return flag; + } + return flag; } void set_value_by_flags(bool *stop_flags, @@ -40,23 +40,23 @@ void set_value_by_flags(bool *stop_flags, const int bs, const int end_length, bool beam_search) { - for (int bi = 0; bi < bs; bi++) { - if (stop_flags[bi]) { - if ((seq_lens[bi] == 0)) { - topk_ids[bi] = -1; - } else { - topk_ids[bi] = end_ids[0]; - next_tokens[bi] = end_ids[0]; - } - } else { - next_tokens[bi] = topk_ids[bi]; - } - if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) { - stop_flags[bi] = true; - topk_ids[bi] = end_ids[0]; - next_tokens[bi] = end_ids[0]; - } + for (int bi = 0; bi < bs; bi++) { + if (stop_flags[bi]) { + if ((seq_lens[bi] == 0)) { + topk_ids[bi] = -1; + } else { + topk_ids[bi] = end_ids[0]; + next_tokens[bi] = end_ids[0]; + } + } else { + next_tokens[bi] = topk_ids[bi]; } + if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) { + stop_flags[bi] = true; + topk_ids[bi] = end_ids[0]; + next_tokens[bi] = end_ids[0]; + } + } } void GetStopFlagsMulti(const paddle::Tensor &topk_ids, @@ -65,17 +65,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &end_ids, const paddle::Tensor &next_tokens, const bool beam_search) { - std::vector shape = topk_ids.shape(); - int64_t bs_now = shape[0]; - int64_t end_length = end_ids.shape()[0]; - set_value_by_flags(const_cast(stop_flags.data()), - const_cast(topk_ids.data()), - const_cast(next_tokens.data()), - end_ids.data(), - seq_lens.data(), - bs_now, - end_length, - false); + std::vector shape = topk_ids.shape(); + int64_t bs_now = shape[0]; + int64_t end_length = end_ids.shape()[0]; + set_value_by_flags(const_cast(stop_flags.data()), + const_cast(topk_ids.data()), + const_cast(next_tokens.data()), + end_ids.data(), + seq_lens.data(), + bs_now, + end_length, + false); } PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu) diff --git a/custom_ops/cpu_ops/token_penalty_multi_scores.cc b/custom_ops/cpu_ops/token_penalty_multi_scores.cc index fdcd56eb6..81b0bed19 100644 --- a/custom_ops/cpu_ops/token_penalty_multi_scores.cc +++ b/custom_ops/cpu_ops/token_penalty_multi_scores.cc @@ -23,16 +23,16 @@ void min_length_logits_process(float *logits, const int64_t bs, const int64_t length, const int64_t end_length) { - for (int bi = 0; bi < bs; ++bi) { - if (cur_len[bi] < 0) { - continue; - } - if (cur_len[bi] < min_len[bi]) { - for (int i = 0; i < end_length; ++i) { - logits[bi * length + eos_token_id[i]] = -1e10; - } - } + for (int bi = 0; bi < bs; ++bi) { + if (cur_len[bi] < 0) { + continue; } + if (cur_len[bi] < min_len[bi]) { + for (int i = 0; i < end_length; ++i) { + logits[bi * length + eos_token_id[i]] = -1e10; + } + } + } } void update_repeat_times(const int64_t *pre_ids, @@ -41,20 +41,20 @@ void update_repeat_times(const int64_t *pre_ids, const int64_t bs, const int64_t length, const int64_t length_id) { - for (int bi = 0; bi < bs; ++bi) { - if (cur_len[bi] < 0) { - continue; - } - const int64_t *pre_ids_now = pre_ids + bi * length_id; - int *repeat_times_now = repeat_times + bi * length; - for (int i = 0; i < length_id; i++) { - int64_t id = pre_ids_now[i]; - if (id < 0) { - break; - } - repeat_times_now[id] += 1; - } + for (int bi = 0; bi < bs; ++bi) { + if (cur_len[bi] < 0) { + continue; } + const int64_t *pre_ids_now = pre_ids + bi * length_id; + int *repeat_times_now = repeat_times + bi * length; + for (int i = 0; i < length_id; i++) { + int64_t id = pre_ids_now[i]; + if (id < 0) { + break; + } + repeat_times_now[id] += 1; + } + } } void update_value_by_repeat_times(const int *repeat_times, @@ -65,24 +65,22 @@ void update_value_by_repeat_times(const int *repeat_times, float *logits, const int64_t bs, const int64_t length) { - for (int bi = 0; bi < bs; ++bi) { - float *logits_now = logits + bi * length; - const int *repeat_times_now = repeat_times + bi * length; - float alpha = static_cast(penalty_scores[bi]); - float beta = static_cast(frequency_score[bi]); - float gamma = static_cast(presence_score[bi]); - for (int i = 0; i < length; ++i) { - int times = repeat_times_now[i]; - float logit_now = static_cast(logits_now[i]); - if (times == 0) { - logits_now[i] = - static_cast(logit_now / temperatures[bi]); - } - logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha; - logits_now[i] = - static_cast(logit_now - times * beta - gamma); - } + for (int bi = 0; bi < bs; ++bi) { + float *logits_now = logits + bi * length; + const int *repeat_times_now = repeat_times + bi * length; + float alpha = static_cast(penalty_scores[bi]); + float beta = static_cast(frequency_score[bi]); + float gamma = static_cast(presence_score[bi]); + for (int i = 0; i < length; ++i) { + int times = repeat_times_now[i]; + float logit_now = static_cast(logits_now[i]); + if (times == 0) { + logits_now[i] = static_cast(logit_now / temperatures[bi]); + } + logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha; + logits_now[i] = static_cast(logit_now - times * beta - gamma); } + } } void ban_bad_words(float *logits, @@ -90,15 +88,14 @@ void ban_bad_words(float *logits, const int64_t bs, const int64_t length, const int64_t bad_words_length) { - for (int bi = 0; bi < bs; ++bi) { - float *logits_now = logits + bi * length; - for (int bwid = 0; bwid < bad_words_length; ++bwid) { - const int64_t bad_words_token_id = bad_words_list[bwid]; - if (bad_words_token_id >= length || bad_words_token_id < 0) - continue; - logits_now[bad_words_token_id] = -1e10; - } + for (int bi = 0; bi < bs; ++bi) { + float *logits_now = logits + bi * length; + for (int bwid = 0; bwid < bad_words_length; ++bwid) { + const int64_t bad_words_token_id = bad_words_list[bwid]; + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + logits_now[bad_words_token_id] = -1e10; } + } } template @@ -112,44 +109,44 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, const paddle::Tensor &cur_len, const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) { - std::vector shape = logits.shape(); - auto repeat_times = - paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); - int64_t bs = shape[0]; - int64_t length = shape[1]; - int64_t length_id = pre_ids.shape()[1]; - int64_t end_length = eos_token_id.shape()[0]; - int64_t length_bad_words = bad_tokens.shape()[0]; + std::vector shape = logits.shape(); + auto repeat_times = + paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); + int64_t bs = shape[0]; + int64_t length = shape[1]; + int64_t length_id = pre_ids.shape()[1]; + int64_t end_length = eos_token_id.shape()[0]; + int64_t length_bad_words = bad_tokens.shape()[0]; - min_length_logits_process(const_cast(logits.data()), - cur_len.data(), - min_len.data(), - eos_token_id.data(), - bs, - length, - end_length); + min_length_logits_process(const_cast(logits.data()), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bs, + length, + end_length); - update_repeat_times(pre_ids.data(), - cur_len.data(), - repeat_times.data(), - bs, - length, - length_id); + update_repeat_times(pre_ids.data(), + cur_len.data(), + repeat_times.data(), + bs, + length, + length_id); - update_value_by_repeat_times(repeat_times.data(), - penalty_scores.data(), - frequency_score.data(), - presence_score.data(), - temperatures.data(), - const_cast(logits.data()), - bs, - length); + update_value_by_repeat_times(repeat_times.data(), + penalty_scores.data(), + frequency_score.data(), + presence_score.data(), + temperatures.data(), + const_cast(logits.data()), + bs, + length); - ban_bad_words(const_cast(logits.data()), - bad_tokens.data(), - bs, - length, - length_bad_words); + ban_bad_words(const_cast(logits.data()), + bad_tokens.data(), + bs, + length, + length_bad_words); } void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, @@ -162,17 +159,17 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, const paddle::Tensor &cur_len, const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) { - return token_penalty_multi_scores_kernel( - pre_ids, - logits, - penalty_scores, - frequency_scores, - presence_scores, - temperatures, - bad_tokens, - cur_len, - min_len, - eos_token_id); + return token_penalty_multi_scores_kernel( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id); } PD_BUILD_STATIC_OP(get_token_penalty_multi_scores_cpu) diff --git a/custom_ops/cpu_ops/update_inputs.cc b/custom_ops/cpu_ops/update_inputs.cc index 5985d737e..c2b748002 100644 --- a/custom_ops/cpu_ops/update_inputs.cc +++ b/custom_ops/cpu_ops/update_inputs.cc @@ -24,50 +24,50 @@ void update_inputs_kernel(bool *not_need_stop, const int64_t *next_tokens, const int bsz, const int input_ids_stride) { - int64_t stop_sum = 0; - for (int bi = 0; bi < bsz; ++bi) { - bool stop_flag_now = false; - int64_t stop_flag_now_int = 0; - stop_flag_now = stop_flags[bi]; - stop_flag_now_int = static_cast(stop_flag_now); - auto seq_len_this_time = seq_lens_this_time[bi]; - auto seq_len_encoder = seq_lens_encoder[bi]; - auto seq_len_decoder = seq_lens_decoder[bi]; - seq_lens_decoder[bi] = - stop_flag_now ? 0 - : (seq_len_decoder == 0 ? seq_len_encoder - : seq_len_decoder + 1); - seq_lens_this_time[bi] = stop_flag_now ? 0 : 1; - seq_lens_encoder[bi] = 0; - int64_t *input_ids_now = input_ids + bi * input_ids_stride; - input_ids_now[0] = next_tokens[bi]; - stop_sum += stop_flag_now_int; - } - not_need_stop[0] = stop_sum < stop_nums[0]; + int64_t stop_sum = 0; + for (int bi = 0; bi < bsz; ++bi) { + bool stop_flag_now = false; + int64_t stop_flag_now_int = 0; + stop_flag_now = stop_flags[bi]; + stop_flag_now_int = static_cast(stop_flag_now); + auto seq_len_this_time = seq_lens_this_time[bi]; + auto seq_len_encoder = seq_lens_encoder[bi]; + auto seq_len_decoder = seq_lens_decoder[bi]; + seq_lens_decoder[bi] = + stop_flag_now + ? 0 + : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1); + seq_lens_this_time[bi] = stop_flag_now ? 0 : 1; + seq_lens_encoder[bi] = 0; + int64_t *input_ids_now = input_ids + bi * input_ids_stride; + input_ids_now[0] = next_tokens[bi]; + stop_sum += stop_flag_now_int; + } + not_need_stop[0] = stop_sum < stop_nums[0]; } void UpdateInputs(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &input_ids, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step) { - const int bsz = input_ids.shape()[0]; - const int input_ids_stride = input_ids.shape()[1]; - update_inputs_kernel(const_cast(not_need_stop.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(input_ids.data()), - stop_nums.data(), - stop_flags.data(), - is_block_step.data(), - next_tokens.data(), - bsz, - input_ids_stride); + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &input_ids, + const paddle::Tensor &stop_nums, + const paddle::Tensor &next_tokens, + const paddle::Tensor &is_block_step) { + const int bsz = input_ids.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + update_inputs_kernel(const_cast(not_need_stop.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), + stop_nums.data(), + stop_flags.data(), + is_block_step.data(), + next_tokens.data(), + bsz, + input_ids_stride); } PD_BUILD_STATIC_OP(update_inputs_cpu) diff --git a/custom_ops/cpu_ops/xft_all_layer_fake.cc b/custom_ops/cpu_ops/xft_all_layer_fake.cc index aeb20004e..ab64d80c8 100644 --- a/custom_ops/cpu_ops/xft_all_layer_fake.cc +++ b/custom_ops/cpu_ops/xft_all_layer_fake.cc @@ -45,18 +45,18 @@ std::vector InvokeAllLLaMALayer( int maxPositions, int maxPosEmbed, int intermediateSize) { - auto out = paddle::empty_like(input); - return {out}; + auto out = paddle::empty_like(input); + return {out}; } std::vector> AllLLaMALayerInferShape( std::vector x_shape) { - return {x_shape}; + return {x_shape}; } std::vector AllLLaMALayerInferDtype( paddle::DataType x_dtype) { - return {x_dtype}; + return {x_dtype}; } PD_BUILD_STATIC_OP(xft_llama_all_layer) diff --git a/custom_ops/cpu_ops/xft_greedy_search_fake.cc b/custom_ops/cpu_ops/xft_greedy_search_fake.cc index ecf57a2ab..060e7da82 100644 --- a/custom_ops/cpu_ops/xft_greedy_search_fake.cc +++ b/custom_ops/cpu_ops/xft_greedy_search_fake.cc @@ -16,20 +16,20 @@ #include "paddle/extension.h" std::vector XftGreedySearch(const paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto next_tokens = - paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place()); - return {next_tokens}; + const int bsz = probs.shape()[0]; + const int vocab_size = probs.shape()[1]; + auto next_tokens = + paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place()); + return {next_tokens}; } std::vector> XftGreedySearchInferShape( const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - return {{bsz, 1}}; + int64_t bsz = probs_shape[0]; + return {{bsz, 1}}; } std::vector XftGreedySearchInferDtype( const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64}; + return {paddle::DataType::INT64}; } PD_BUILD_STATIC_OP(xft_greedy_search) .Inputs({"probs"}) diff --git a/custom_ops/iluvatar_ops/fused_moe_imp_op.h b/custom_ops/iluvatar_ops/fused_moe_imp_op.h index 254f80e67..3108df789 100644 --- a/custom_ops/iluvatar_ops/fused_moe_imp_op.h +++ b/custom_ops/iluvatar_ops/fused_moe_imp_op.h @@ -16,8 +16,8 @@ */ #pragma once -#include #include +#include #include "cub/cub.cuh" namespace phi { diff --git a/custom_ops/iluvatar_ops/fused_moe_op.h b/custom_ops/iluvatar_ops/fused_moe_op.h index 91bd589f7..8f4fb80a8 100644 --- a/custom_ops/iluvatar_ops/fused_moe_op.h +++ b/custom_ops/iluvatar_ops/fused_moe_op.h @@ -19,8 +19,8 @@ #include #include -#include "fused_moe_imp_op.h" #include "fused_moe_helper.h" +#include "fused_moe_imp_op.h" // Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -34,8 +34,8 @@ namespace phi { struct GpuLaunchConfig { - dim3 block_per_grid; - dim3 thread_per_block; + dim3 block_per_grid; + dim3 thread_per_block; }; inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { @@ -81,7 +81,6 @@ __launch_bounds__(TPB) __global__ cub::Sum sum; float threadData(-FLT_MAX); - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; threadData = max(static_cast(input[idx]), threadData); @@ -275,7 +274,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { const int idx = thread_read_offset + expert; inp_kvp.key = expert; - inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; + inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] + : inputs_after_softmax[idx]; for (int prior_k = 0; prior_k < k_idx; ++prior_k) { const IdxT prior_winning_expert = indices[k * block_row + prior_k]; @@ -292,7 +292,9 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); if (threadIdx.x == 0) { const int idx = k * block_row + k_idx; - output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + output[idx] = + bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] + : result_kvp.value; indices[idx] = should_process_row ? result_kvp.key : num_experts; source_rows[idx] = k_idx * num_rows + block_row; } @@ -301,14 +303,15 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, } template -__launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, - const T* bias, - T* output, - IdxT* indices, - int* source_rows, - const int64_t num_experts, - const int64_t k, - const int64_t num_rows) { +__launch_bounds__(TPB) __global__ + void moe_softmax_top_k_fused(const T* input, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { // softmax using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -321,11 +324,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, return; } const int64_t thread_row_offset = globalIdx * num_experts; - const int64_t idx = thread_row_offset+threadIdx.x; + const int64_t idx = thread_row_offset + threadIdx.x; cub::Sum sum; - float threadData = (threadIdx.x < num_experts) ? static_cast(input[idx]) :(-FLT_MAX); + float threadData = + (threadIdx.x < num_experts) ? static_cast(input[idx]) : (-FLT_MAX); const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); if (threadIdx.x == 0) { @@ -377,7 +381,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max); if (threadIdx.x == 0) { const int cur_idx = k * globalIdx + k_idx; - output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; + output[cur_idx] = + bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; indices[cur_idx] = result_kvp.key; source_rows[cur_idx] = k_idx * num_rows + globalIdx; } @@ -386,14 +391,15 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input, } template -__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax, - const T* bias, - T* output, - IdxT* indices, - int* source_rows, - const int64_t num_experts, - const int64_t k, - const int64_t num_rows) { +__launch_bounds__(TPB) __global__ + void moe_top_k_normed(const T* inputs_after_softmax, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -422,7 +428,8 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { const int idx = thread_read_offset + expert; inp_kvp.key = expert; - inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx] ; + inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] + : inputs_after_softmax[idx]; for (int prior_k = 0; prior_k < k_idx; ++prior_k) { const int prior_winning_expert = indices[k * block_row + prior_k]; @@ -439,11 +446,14 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); if (threadIdx.x == 0) { const int idx = k * block_row + k_idx; - // output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + // output[idx] = bias ? inputs_after_softmax[thread_read_offset + + // result_kvp.key]: result_kvp.value; indices[idx] = should_process_row ? result_kvp.key : num_experts; source_rows[idx] = k_idx * num_rows + block_row; - T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]: result_kvp.value; + T row_out = + bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] + : result_kvp.value; row_outputs[k_idx] = row_out; weight_sum += row_out; } @@ -458,16 +468,16 @@ __launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_so } } - template -__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input, - const T* bias, - T* output, - IdxT* indices, - int* source_rows, - const int64_t num_experts, - const int64_t k, - const int64_t num_rows) { +__launch_bounds__(TPB) __global__ + void moe_softmax_top_k_normed_fused(const T* input, + const T* bias, + T* output, + IdxT* indices, + int* source_rows, + const int64_t num_experts, + const int64_t k, + const int64_t num_rows) { // softmax using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -480,11 +490,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i return; } const int64_t thread_row_offset = globalIdx * num_experts; - const int64_t idx = thread_row_offset+threadIdx.x; + const int64_t idx = thread_row_offset + threadIdx.x; cub::Sum sum; - float threadData = (threadIdx.x < num_experts) ? static_cast(input[idx]) :(-FLT_MAX); + float threadData = + (threadIdx.x < num_experts) ? static_cast(input[idx]) : (-FLT_MAX); const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); if (threadIdx.x == 0) { @@ -542,7 +553,8 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i if (threadIdx.x == 0) { const int cur_idx = k * globalIdx + k_idx; - T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; + T row_out = + bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value; row_outputs[k_idx] = row_out; weight_sum += row_out; @@ -595,29 +607,36 @@ void topk_gating_softmax_kernelLauncher(const T* input, if (topk_only_mode) { static constexpr int TPB = 256; const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); - moe_top_k<<>>( - input, gating_correction_bias, output, indices, source_row, num_experts, k, num_rows); + moe_top_k + <<>>(input, + gating_correction_bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); return; } static constexpr int WARPS_PER_TB = 4; - #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ +#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ case N: { \ topk_gating_softmax_launcher_helper( \ input, output, indices, source_row, num_rows, num_experts, k, stream); \ break; \ } int64_t tem_num_experts = num_experts; - if(gating_correction_bias != nullptr) tem_num_experts = 0; + if (gating_correction_bias != nullptr) tem_num_experts = 0; switch (tem_num_experts) { - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128) - //LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128) + // LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256) default: { static constexpr int TPB = 256; @@ -646,15 +665,15 @@ void topk_gating_softmax_kernelLauncher(const T* input, const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); moe_softmax<<>>( input, softmax, num_experts, num_rows); - moe_top_k - <<>>(softmax, - gating_correction_bias, - output, - indices, - source_row, - num_experts, - k, - num_rows); + moe_top_k<<>>( + softmax, + gating_correction_bias, + output, + indices, + source_row, + num_experts, + k, + num_rows); } } } diff --git a/custom_ops/iluvatar_ops/mixed_fused_attn.cu b/custom_ops/iluvatar_ops/mixed_fused_attn.cu index 01fb37332..b5388cc2c 100644 --- a/custom_ops/iluvatar_ops/mixed_fused_attn.cu +++ b/custom_ops/iluvatar_ops/mixed_fused_attn.cu @@ -23,8 +23,8 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv, const paddle::Tensor& decode_block_table, const paddle::Tensor& cu_seqlens_qkv, const paddle::Tensor& seq_lens, - const paddle::optional &rope_sin, - const paddle::optional &rope_cos, + const paddle::optional& rope_sin, + const paddle::optional& rope_cos, int prefill_num_tokens, int num_heads, int head_dim, @@ -42,318 +42,354 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv, bool enable_cuda_graph, bool use_sqrt_alibi, paddle::Tensor& out) { + typedef PDTraits traits_; + typedef typename traits_::data_t data_t; - typedef PDTraits traits_; - typedef typename traits_::data_t data_t; + const auto& dtype = qkv.dtype(); + cuinferDataType_t cuinfer_data_type; + cudaDataType_t cu_data_type; + if (dtype == paddle::DataType::FLOAT16) { + cuinfer_data_type = CUINFER_DATA_HALF; + cu_data_type = CUDA_R_16F; + } else { + cuinfer_data_type = CUINFER_DATA_BFLOAT16; + cu_data_type = CUDA_R_16BF; + } - const auto& dtype = qkv.dtype(); - cuinferDataType_t cuinfer_data_type; - cudaDataType_t cu_data_type; - if (dtype == paddle::DataType::FLOAT16) { - cuinfer_data_type = CUINFER_DATA_HALF; - cu_data_type = CUDA_R_16F; - } else { - cuinfer_data_type = CUINFER_DATA_BFLOAT16; - cu_data_type = CUDA_R_16BF; - } + const auto& qkv_dims = qkv.dims(); + const auto& kv_cache_dims = k_cache.dims(); + const auto& prefill_block_table_dims = prefill_block_table.dims(); + const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims(); - const auto& qkv_dims = qkv.dims(); - const auto& kv_cache_dims = k_cache.dims(); - const auto& prefill_block_table_dims = prefill_block_table.dims(); - const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims(); + int prefill_batch_size = prefill_block_table_dims[0]; + int num_tokens = qkv_dims[0]; + int decode_num_tokens = num_tokens - prefill_num_tokens; + int num_total_heads = num_heads + 2 * num_kv_heads; + int max_num_blocks_per_seq = prefill_block_table_dims[1]; + int qkv_stride = qkv.strides()[0]; + int num_blocks = kv_cache_dims[0]; - int prefill_batch_size = prefill_block_table_dims[0]; - int num_tokens = qkv_dims[0]; - int decode_num_tokens = num_tokens - prefill_num_tokens; - int num_total_heads = num_heads + 2 * num_kv_heads; - int max_num_blocks_per_seq = prefill_block_table_dims[1]; - int qkv_stride = qkv.strides()[0]; - int num_blocks = kv_cache_dims[0]; + int kv_block_stride = k_cache.strides()[0]; + int kv_head_stride = k_cache.strides()[1]; + int block_table_stride = prefill_block_table.strides()[0]; + const float* rope_sin_ptr = rope_sin ? rope_sin.get().data() : nullptr; + const float* rope_cos_ptr = rope_cos ? rope_cos.get().data() : nullptr; - int kv_block_stride = k_cache.strides()[0]; - int kv_head_stride = k_cache.strides()[1]; - int block_table_stride = prefill_block_table.strides()[0]; - const float *rope_sin_ptr = rope_sin ? rope_sin.get().data() : nullptr; - const float *rope_cos_ptr = rope_cos ? rope_cos.get().data() : nullptr; - - cuinferTensorDescriptor_t qkv_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t qkv_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( qkv_desc, cuinfer_data_type, 3, std::vector({prefill_num_tokens, num_total_heads, head_dim}).data(), std::vector({num_total_heads * head_dim, head_dim, 1}).data())); - cuinferTensorDescriptor_t qkv_seqlens_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t qkv_seqlens_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( qkv_seqlens_desc, CUINFER_DATA_INT32, 1, std::vector({prefill_batch_size + 1}).data(), std::vector({1}).data())); - cuinferTensorDescriptor_t block_table_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t block_table_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( block_table_desc, CUINFER_DATA_INT32, 2, std::vector({prefill_batch_size, block_table_stride}).data(), std::vector({block_table_stride, 1}).data())); - cuinferTensorDescriptor_t o_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t o_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( o_desc, cuinfer_data_type, 3, std::vector({prefill_num_tokens, num_heads, head_dim}).data(), std::vector({num_heads * head_dim, head_dim, 1}).data())); - cuinferTensorDescriptor_t k_cache_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t k_cache_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( k_cache_desc, cuinfer_data_type, 4, std::vector({num_blocks, num_kv_heads, block_size, head_dim}).data(), - std::vector({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data())); + std::vector({num_kv_heads * block_size * head_dim, + block_size * head_dim, + head_dim, + 1}) + .data())); - cuinferTensorDescriptor_t v_cache_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t v_cache_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( v_cache_desc, cuinfer_data_type, 4, std::vector({num_blocks, num_kv_heads, block_size, head_dim}).data(), - std::vector({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data())); + std::vector({num_kv_heads * block_size * head_dim, + block_size * head_dim, + head_dim, + 1}) + .data())); - cuinferTensorDescriptor_t cos_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t cos_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( cos_desc, CUINFER_DATA_FLOAT, 2, std::vector({max_seq_len, head_dim}).data(), std::vector({head_dim, 1}).data())); - cuinferTensorDescriptor_t sin_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t sin_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( sin_desc, CUINFER_DATA_FLOAT, 2, std::vector({max_seq_len, head_dim}).data(), std::vector({head_dim, 1}).data())); - cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle(); + cuinferHandle_t cuinfer_handle = + iluvatar::getContextInstance()->getIxInferHandle(); - size_t prefill_workspace_size = 0; - CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens, - num_heads, - num_kv_heads, - head_dim, - q_rope, - k_rope, - v_rope, - cuinfer_data_type, - cuinfer_data_type, - cuinfer_data_type, - &prefill_workspace_size)); - - auto* allocator = paddle::GetAllocator(qkv.place()); - - phi::Allocator::AllocationPtr prefill_tmp_workspace = allocator->Allocate(prefill_workspace_size); - void* prefill_workspace_ptr = prefill_tmp_workspace->ptr(); - - CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle, - qkv_desc, - qkv.data(), - qkv_seqlens_desc, - cu_seqlens_qkv.data(), - block_table_desc, - prefill_block_table.data(), - o_desc, - out.data(), - k_cache_desc, - k_cache.data(), - v_cache_desc, - v_cache.data(), - prefill_workspace_ptr, - prefill_workspace_size, - cos_desc, - rope_cos_ptr, - sin_desc, - rope_sin_ptr, - prefill_batch_size, + size_t prefill_workspace_size = 0; + CUINFER_CHECK( + cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens, num_heads, num_kv_heads, head_dim, - causal, - scale, q_rope, k_rope, - v_rope)); + v_rope, + cuinfer_data_type, + cuinfer_data_type, + cuinfer_data_type, + &prefill_workspace_size)); - size_t decode_workspace_size = 0; - CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens, - num_heads, - num_kv_heads, - head_dim, - block_size, - max_seq_len, - &decode_workspace_size)); + auto* allocator = paddle::GetAllocator(qkv.place()); - phi::Allocator::AllocationPtr decode_tmp_workspace = allocator->Allocate(decode_workspace_size); - void* decode_workspace_ptr = decode_tmp_workspace->ptr(); + phi::Allocator::AllocationPtr prefill_tmp_workspace = + allocator->Allocate(prefill_workspace_size); + void* prefill_workspace_ptr = prefill_tmp_workspace->ptr(); - void* decode_qkv_ptr = (void*)(qkv.data() + prefill_num_tokens * qkv_stride); - void* decode_out_ptr = (void*)(out.data() + prefill_num_tokens * out.strides()[0]); + CUINFER_CHECK( + cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle, + qkv_desc, + qkv.data(), + qkv_seqlens_desc, + cu_seqlens_qkv.data(), + block_table_desc, + prefill_block_table.data(), + o_desc, + out.data(), + k_cache_desc, + k_cache.data(), + v_cache_desc, + v_cache.data(), + prefill_workspace_ptr, + prefill_workspace_size, + cos_desc, + rope_cos_ptr, + sin_desc, + rope_sin_ptr, + prefill_batch_size, + num_heads, + num_kv_heads, + head_dim, + causal, + scale, + q_rope, + k_rope, + v_rope)); - PageAttentionWithKVCacheArguments args{ - static_cast(scale), 1.0, 1.0, static_cast(softcap), window_left, window_right, - causal, use_sqrt_alibi, enable_cuda_graph, false, nullptr, decode_qkv_ptr, decode_qkv_ptr, - decode_workspace_ptr, true, rope_sin_ptr, rope_cos_ptr}; + size_t decode_workspace_size = 0; + CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens, + num_heads, + num_kv_heads, + head_dim, + block_size, + max_seq_len, + &decode_workspace_size)); - CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle, - decode_out_ptr, - cu_data_type, + phi::Allocator::AllocationPtr decode_tmp_workspace = + allocator->Allocate(decode_workspace_size); + void* decode_workspace_ptr = decode_tmp_workspace->ptr(); + + void* decode_qkv_ptr = + (void*)(qkv.data() + prefill_num_tokens * qkv_stride); + void* decode_out_ptr = + (void*)(out.data() + prefill_num_tokens * out.strides()[0]); + + PageAttentionWithKVCacheArguments args{static_cast(scale), + 1.0, + 1.0, + static_cast(softcap), + window_left, + window_right, + causal, + use_sqrt_alibi, + enable_cuda_graph, + false, + nullptr, decode_qkv_ptr, - cu_data_type, - decode_num_tokens, - num_heads, - num_kv_heads, - head_dim, - qkv_stride, - kv_block_stride, - kv_head_stride, - k_cache.data(), - cu_data_type, - v_cache.data(), - cu_data_type, - block_size, - max_num_blocks_per_seq, - max_seq_len, - decode_block_table.data(), - seq_lens.data(), - args)); + decode_qkv_ptr, + decode_workspace_ptr, + true, + rope_sin_ptr, + rope_cos_ptr}; - CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc)); + CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle, + decode_out_ptr, + cu_data_type, + decode_qkv_ptr, + cu_data_type, + decode_num_tokens, + num_heads, + num_kv_heads, + head_dim, + qkv_stride, + kv_block_stride, + kv_head_stride, + k_cache.data(), + cu_data_type, + v_cache.data(), + cu_data_type, + block_size, + max_num_blocks_per_seq, + max_seq_len, + decode_block_table.data(), + seq_lens.data(), + args)); + + CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc)); } -std::vector MixedFusedPagedAttn(const paddle::Tensor& qkv, - paddle::Tensor& k_cache, - paddle::Tensor& v_cache, - const paddle::Tensor& prefill_block_table, - const paddle::Tensor& decode_block_table, - const paddle::Tensor& cu_seqlens_qkv, - const paddle::Tensor& seq_lens, - const paddle::optional &rope_sin, - const paddle::optional &rope_cos, - int prefill_num_tokens, - int num_heads, - int head_dim, - int num_kv_heads, - int block_size, - int max_seq_len, - float scale, - bool causal, - bool q_rope, - bool k_rope, - bool v_rope, - int window_left, - int window_right, - float softcap, - bool enable_cuda_graph, - bool use_sqrt_alibi) { - const auto dtype = qkv.dtype(); - auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place()); +std::vector MixedFusedPagedAttn( + const paddle::Tensor& qkv, + paddle::Tensor& k_cache, + paddle::Tensor& v_cache, + const paddle::Tensor& prefill_block_table, + const paddle::Tensor& decode_block_table, + const paddle::Tensor& cu_seqlens_qkv, + const paddle::Tensor& seq_lens, + const paddle::optional& rope_sin, + const paddle::optional& rope_cos, + int prefill_num_tokens, + int num_heads, + int head_dim, + int num_kv_heads, + int block_size, + int max_seq_len, + float scale, + bool causal, + bool q_rope, + bool k_rope, + bool v_rope, + int window_left, + int window_right, + float softcap, + bool enable_cuda_graph, + bool use_sqrt_alibi) { + const auto dtype = qkv.dtype(); + auto out = + paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place()); - switch (dtype) { - case paddle::DataType::BFLOAT16: - MixedFusedPagedAttnKernel(qkv, - k_cache, - v_cache, - prefill_block_table, - decode_block_table, - cu_seqlens_qkv, - seq_lens, - rope_sin, - rope_cos, - prefill_num_tokens, - num_heads, - head_dim, - num_kv_heads, - block_size, - max_seq_len, - scale, - causal, - q_rope, - k_rope, - v_rope, - window_left, - window_right, - softcap, - enable_cuda_graph, - use_sqrt_alibi, - out); - break; - case paddle::DataType::FLOAT16: - MixedFusedPagedAttnKernel(qkv, - k_cache, - v_cache, - prefill_block_table, - decode_block_table, - cu_seqlens_qkv, - seq_lens, - rope_sin, - rope_cos, - prefill_num_tokens, - num_heads, - head_dim, - num_kv_heads, - block_size, - max_seq_len, - scale, - causal, - q_rope, - k_rope, - v_rope, - window_left, - window_right, - softcap, - enable_cuda_graph, - use_sqrt_alibi, - out); - break; - default: - PD_THROW("Unsupported data type for mixed paged attn"); - } - return {out}; + switch (dtype) { + case paddle::DataType::BFLOAT16: + MixedFusedPagedAttnKernel(qkv, + k_cache, + v_cache, + prefill_block_table, + decode_block_table, + cu_seqlens_qkv, + seq_lens, + rope_sin, + rope_cos, + prefill_num_tokens, + num_heads, + head_dim, + num_kv_heads, + block_size, + max_seq_len, + scale, + causal, + q_rope, + k_rope, + v_rope, + window_left, + window_right, + softcap, + enable_cuda_graph, + use_sqrt_alibi, + out); + break; + case paddle::DataType::FLOAT16: + MixedFusedPagedAttnKernel(qkv, + k_cache, + v_cache, + prefill_block_table, + decode_block_table, + cu_seqlens_qkv, + seq_lens, + rope_sin, + rope_cos, + prefill_num_tokens, + num_heads, + head_dim, + num_kv_heads, + block_size, + max_seq_len, + scale, + causal, + q_rope, + k_rope, + v_rope, + window_left, + window_right, + softcap, + enable_cuda_graph, + use_sqrt_alibi, + out); + break; + default: + PD_THROW("Unsupported data type for mixed paged attn"); + } + return {out}; } -std::vector> MixedFusedPagedAttnInferShape(const std::vector& qkv_shape, - int num_heads, - int head_dim) { - return {{qkv_shape[0], num_heads * head_dim}}; +std::vector> MixedFusedPagedAttnInferShape( + const std::vector& qkv_shape, int num_heads, int head_dim) { + return {{qkv_shape[0], num_heads * head_dim}}; } -std::vector MixedFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) { - return {qkv_dtype}; +std::vector MixedFusedPagedAttnInferDtype( + const paddle::DataType& qkv_dtype) { + return {qkv_dtype}; } PD_BUILD_STATIC_OP(mixed_fused_paged_attn) - .Inputs({"qkv", "k_cache", "v_cache", "prefill_block_table", "decode_block_table", - "cu_seqlens_qkv", "seq_lens", paddle::Optional("rope_sin"), paddle::Optional("rope_cos")}) + .Inputs({"qkv", + "k_cache", + "v_cache", + "prefill_block_table", + "decode_block_table", + "cu_seqlens_qkv", + "seq_lens", + paddle::Optional("rope_sin"), + paddle::Optional("rope_cos")}) .Outputs({"out"}) .Attrs({"prefill_num_tokens:int", "num_heads: int", @@ -362,14 +398,14 @@ PD_BUILD_STATIC_OP(mixed_fused_paged_attn) "block_size:int", "max_seq_len:int", "scale:float", - "causal:bool", - "q_rope:bool", + "causal:bool", + "q_rope:bool", "k_rope:bool", "v_rope:bool", "window_left:int", "window_right:int", "softcap:float", - "enable_cuda_graph:bool", + "enable_cuda_graph:bool", "use_sqrt_alibi:bool"}) .SetKernelFn(PD_KERNEL(MixedFusedPagedAttn)) .SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape)) diff --git a/custom_ops/iluvatar_ops/moe_dispatch.cu b/custom_ops/iluvatar_ops/moe_dispatch.cu index f6e8bb682..bbb4be6e5 100644 --- a/custom_ops/iluvatar_ops/moe_dispatch.cu +++ b/custom_ops/iluvatar_ops/moe_dispatch.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - // Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -29,10 +28,10 @@ __global__ void compute_total_rows_before_expert_kernel( const int64_t sorted_experts_len, const int64_t num_experts, int64_t* total_rows_before_expert) { - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - if (expert >= num_experts) return; - total_rows_before_expert[expert] = - phi::find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + total_rows_before_expert[expert] = phi::find_total_elts_leq_target( + sorted_experts, sorted_experts_len, expert); } void compute_total_rows_before_expert(int* sorted_indices, @@ -40,36 +39,38 @@ void compute_total_rows_before_expert(int* sorted_indices, const int64_t num_experts, int64_t* total_rows_before_expert, cudaStream_t stream) { - const int threads = std::min(int64_t(1024), num_experts); - const int blocks = (num_experts + threads - 1) / threads; + const int threads = std::min(int64_t(1024), num_experts); + const int blocks = (num_experts + threads - 1) / threads; - compute_total_rows_before_expert_kernel<<>>( - sorted_indices, total_indices, num_experts, total_rows_before_expert); + compute_total_rows_before_expert_kernel<<>>( + sorted_indices, total_indices, num_experts, total_rows_before_expert); } template -void MoeDispatchKernel(const paddle::Tensor& input, - const paddle::Tensor& gating_output, - const paddle::optional& gating_correction_bias, - const int moe_topk, - const bool group_moe, - const std::string &moe_quant_type, - const bool topk_only_mode, - const int num_rows, - const int hidden_size, - const int expert_num, - paddle::Tensor* permute_input, - paddle::Tensor* tokens_expert_prefix_sum, - paddle::Tensor* permute_indices_per_token, - paddle::Tensor* top_k_weight, - paddle::Tensor* top_k_indices) { +void MoeDispatchKernel( + const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const paddle::optional& gating_correction_bias, + const int moe_topk, + const bool group_moe, + const std::string& moe_quant_type, + const bool topk_only_mode, + const int num_rows, + const int hidden_size, + const int expert_num, + paddle::Tensor* permute_input, + paddle::Tensor* tokens_expert_prefix_sum, + paddle::Tensor* permute_indices_per_token, + paddle::Tensor* top_k_weight, + paddle::Tensor* top_k_indices) { using namespace phi; typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; auto place = input.place(); - auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(input.place())); + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(input.place())); auto stream = static_cast(dev_ctx->stream()); if (group_moe) { // Check if expert_num is divisible by moe_topk, else throw an error @@ -131,19 +132,21 @@ void MoeDispatchKernel(const paddle::Tensor& input, softmax_out_ = nullptr; } - topk_gating_softmax_kernelLauncher(gating_output.data(), - gating_correction_bias ? gating_correction_bias.get().data() : nullptr, - top_k_weight->data(), - softmax_out_, - expert_for_source_row, - source_rows_, - softmax_max_prob, - num_rows, - expert_num, - moe_topk, - group_moe, - stream, - topk_only_mode); + topk_gating_softmax_kernelLauncher( + gating_output.data(), + gating_correction_bias ? gating_correction_bias.get().data() + : nullptr, + top_k_weight->data(), + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + expert_num, + moe_topk, + group_moe, + stream, + topk_only_mode); sorter_.run(reinterpret_cast(sorter_ws_ptr), sorter_ws_size_bytes, @@ -155,7 +158,6 @@ void MoeDispatchKernel(const paddle::Tensor& input, false, stream); - initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), @@ -167,16 +169,13 @@ void MoeDispatchKernel(const paddle::Tensor& input, moe_topk, stream); - - compute_total_rows_before_expert( - permuted_experts_, - moe_topk * num_rows, - expert_num, - tokens_expert_prefix_sum->data(), - stream); + compute_total_rows_before_expert(permuted_experts_, + moe_topk * num_rows, + expert_num, + tokens_expert_prefix_sum->data(), + stream); } - std::vector MoeExpertDispatch( const paddle::Tensor& input, const paddle::Tensor& gating_output, @@ -184,7 +183,7 @@ std::vector MoeExpertDispatch( const paddle::optional& w4a8_in_scale, const int moe_topk, const bool group_moe, - const std::string &moe_quant_type, + const std::string& moe_quant_type, const bool topk_only_mode) { const auto input_type = input.dtype(); auto place = input.place(); @@ -214,7 +213,6 @@ std::vector MoeExpertDispatch( auto permute_indices_per_token = GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place); - switch (input_type) { case paddle::DataType::BFLOAT16: MoeDispatchKernel(input, @@ -261,7 +259,6 @@ std::vector MoeExpertDispatch( top_k_indices}; } - std::vector> MoeExpertDispatchInferShape( const std::vector& input_shape, const std::vector& gating_output_shape, @@ -299,17 +296,21 @@ std::vector MoeExpertDispatchInferDtype( paddle::DataType::INT32}; } - PD_BUILD_STATIC_OP(moe_expert_dispatch) - .Inputs({"input", "gating_output", paddle::Optional("gating_correction_bias"), - paddle::Optional("w4a8_in_scale")}) + .Inputs({"input", + "gating_output", + paddle::Optional("gating_correction_bias"), + paddle::Optional("w4a8_in_scale")}) .Outputs({"permute_input", "tokens_expert_prefix_sum", "permute_indices_per_token", "top_k_weight", "top_k_indices", "expert_idx_per_token"}) - .Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"}) + .Attrs({"moe_topk:int", + "group_moe:bool", + "moe_quant_type:std::string", + "topk_only_mode:bool"}) .SetKernelFn(PD_KERNEL(MoeExpertDispatch)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype)); diff --git a/custom_ops/iluvatar_ops/moe_reduce.cu b/custom_ops/iluvatar_ops/moe_reduce.cu index 8e58db47d..95cad5517 100644 --- a/custom_ops/iluvatar_ops/moe_reduce.cu +++ b/custom_ops/iluvatar_ops/moe_reduce.cu @@ -16,9 +16,9 @@ #pragma once -#include "helper.h" #include "fused_moe_helper.h" #include "fused_moe_op.h" +#include "helper.h" template void MoeReduceKernel(const paddle::Tensor& ffn_out, @@ -32,27 +32,28 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out, const int hidden_size, const int topk, paddle::Tensor* output) { - using namespace phi; - typedef PDTraits traits_; - typedef typename traits_::DataType DataType_; - typedef typename traits_::data_t data_t; - auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place())); - auto stream = static_cast(dev_ctx->stream()); + using namespace phi; + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(ffn_out.place())); + auto stream = static_cast(dev_ctx->stream()); - finalize_moe_routing_kernelLauncher( - ffn_out.data(), - output->data(), - down_proj_bias ? down_proj_bias->data() : nullptr, - top_k_weight.data(), - permute_indices_per_token.data(), - top_k_indices.data(), - num_rows, - hidden_size, - topk, - static_cast(1), - norm_topk_prob, - routed_scaling_factor, - stream); + finalize_moe_routing_kernelLauncher( + ffn_out.data(), + output->data(), + down_proj_bias ? down_proj_bias->data() : nullptr, + top_k_weight.data(), + permute_indices_per_token.data(), + top_k_indices.data(), + num_rows, + hidden_size, + topk, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); } paddle::Tensor MoeExpertReduceFunc( @@ -63,48 +64,46 @@ paddle::Tensor MoeExpertReduceFunc( const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { - const auto input_type = ffn_out.dtype(); - auto place = ffn_out.place(); + const auto input_type = ffn_out.dtype(); + auto place = ffn_out.place(); - const int topk = top_k_indices.dims()[1]; - const int num_rows = ffn_out.dims()[0] / topk; - const int hidden_size = ffn_out.dims()[1]; + const int topk = top_k_indices.dims()[1]; + const int num_rows = ffn_out.dims()[0] / topk; + const int hidden_size = ffn_out.dims()[1]; - auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); + auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); - switch (input_type) { - case paddle::DataType::BFLOAT16: - MoeReduceKernel( - ffn_out, - top_k_weight, - permute_indices_per_token, - top_k_indices, - down_proj_bias, - norm_topk_prob, - routed_scaling_factor, - num_rows, - hidden_size, - topk, - &output); - break; - case paddle::DataType::FLOAT16: - MoeReduceKernel( - ffn_out, - top_k_weight, - permute_indices_per_token, - top_k_indices, - down_proj_bias, - norm_topk_prob, - routed_scaling_factor, - num_rows, - hidden_size, - topk, - &output); - break; - default: - PD_THROW("Unsupported data type for MoeDispatchKernel"); - } - return output; + switch (input_type) { + case paddle::DataType::BFLOAT16: + MoeReduceKernel(ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + down_proj_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + topk, + &output); + break; + case paddle::DataType::FLOAT16: + MoeReduceKernel(ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + down_proj_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + topk, + &output); + break; + default: + PD_THROW("Unsupported data type for MoeDispatchKernel"); + } + return output; } std::vector MoeExpertReduce( @@ -115,13 +114,13 @@ std::vector MoeExpertReduce( const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { - return {MoeExpertReduceFunc(ffn_out, - top_k_weight, - permute_indices_per_token, - top_k_indices, - down_proj_bias, - norm_topk_prob, - routed_scaling_factor)}; + return {MoeExpertReduceFunc(ffn_out, + top_k_weight, + permute_indices_per_token, + top_k_indices, + down_proj_bias, + norm_topk_prob, + routed_scaling_factor)}; } std::vector> MoeExpertReduceInferShape( @@ -130,7 +129,7 @@ std::vector> MoeExpertReduceInferShape( const std::vector& permute_indices_per_token_shape, const std::vector& top_k_indices_shape, const paddle::optional>& down_proj_bias_shape) { - return {ffn_out_shape}; + return {ffn_out_shape}; } std::vector MoeExpertReduceInferDtype( @@ -139,7 +138,7 @@ std::vector MoeExpertReduceInferDtype( const paddle::DataType& permute_indices_per_token_dtype, const paddle::DataType& top_k_indices_dtype, const paddle::optional& down_proj_bias_dtype) { - return {ffn_out_dtype}; + return {ffn_out_dtype}; } PD_BUILD_STATIC_OP(moe_expert_reduce) diff --git a/custom_ops/iluvatar_ops/paged_attn.cu b/custom_ops/iluvatar_ops/paged_attn.cu index ca1ddde72..9d2c19a17 100644 --- a/custom_ops/iluvatar_ops/paged_attn.cu +++ b/custom_ops/iluvatar_ops/paged_attn.cu @@ -15,18 +15,17 @@ #include "helper.h" #include "iluvatar_context.h" - template void PagedAttnKernel(const paddle::Tensor& q, const paddle::Tensor& k_cache, const paddle::Tensor& v_cache, const paddle::Tensor& block_table, const paddle::Tensor& seq_lens, - const paddle::optional &alibi_slopes, - const paddle::optional &k, - const paddle::optional &v, - const paddle::optional &rope_sin, - const paddle::optional &rope_cos, + const paddle::optional& alibi_slopes, + const paddle::optional& k, + const paddle::optional& v, + const paddle::optional& rope_sin, + const paddle::optional& rope_cos, int num_heads, int head_dim, int num_kv_heads, @@ -41,298 +40,326 @@ void PagedAttnKernel(const paddle::Tensor& q, bool use_sqrt_alibi, bool merged_qkv, paddle::Tensor& out) { - if (alibi_slopes) { - PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(), - paddle::DataType::FLOAT32, - common::errors::InvalidArgument( - "paged_attention expects alibi_slopes float tensor")); - PADDLE_ENFORCE_EQ(alibi_slopes.get().is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects alibi_slopes is contiguous")); - } + if (alibi_slopes) { + PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(), + paddle::DataType::FLOAT32, + common::errors::InvalidArgument( + "paged_attention expects alibi_slopes float tensor")); + PADDLE_ENFORCE_EQ( + alibi_slopes.get().is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects alibi_slopes is contiguous")); + } - // check dtype and contiguous - const auto& dtype = q.dtype(); - cudaDataType_t data_type; - if (dtype == paddle::DataType::FLOAT16) { - data_type = CUDA_R_16F; - } else if (dtype == paddle::DataType::BFLOAT16) { - data_type = CUDA_R_16BF; - } else { - common::errors::InvalidArgument("paged_attention support half and bfloat16 now"); - } + // check dtype and contiguous + const auto& dtype = q.dtype(); + cudaDataType_t data_type; + if (dtype == paddle::DataType::FLOAT16) { + data_type = CUDA_R_16F; + } else if (dtype == paddle::DataType::BFLOAT16) { + data_type = CUDA_R_16BF; + } else { + common::errors::InvalidArgument( + "paged_attention support half and bfloat16 now"); + } - PADDLE_ENFORCE_EQ(k_cache.dtype(), - dtype, - common::errors::InvalidArgument( - "k_cache dtype must be the same as query dtype")); - PADDLE_ENFORCE_EQ(k_cache.is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects k_cache is contiguous")); - PADDLE_ENFORCE_EQ(block_table.dtype(), - paddle::DataType::INT32, - common::errors::InvalidArgument( - "block_table dtype must be int32")); - PADDLE_ENFORCE_EQ(block_table.is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects block_table is contiguous")); - PADDLE_ENFORCE_EQ(seq_lens.dtype(), - paddle::DataType::INT32, - common::errors::InvalidArgument( - "seq_lens dtype must be int32")); - PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects seq_lens is contiguous")); - // check dim and shape - // k_cache: [num_blocks, kv_num_heads, block_size, head_dim] - // v_cache: [num_blocks, kv_num_heads, block_size, head_dim] - // block_table: [num_seqs, max_num_blocks_per_seq] - // seq_lens: [num_seqs] - // q and out: - // if merged_qkv = false: - // q:[num_seqs, hidden_size] - // out:[num_seqs, hidden_size] - // if merged_qkv = true: - // q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim] - // out: [num_seqs, hidden_size] + PADDLE_ENFORCE_EQ(k_cache.dtype(), + dtype, + common::errors::InvalidArgument( + "k_cache dtype must be the same as query dtype")); + PADDLE_ENFORCE_EQ(k_cache.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects k_cache is contiguous")); + PADDLE_ENFORCE_EQ( + block_table.dtype(), + paddle::DataType::INT32, + common::errors::InvalidArgument("block_table dtype must be int32")); + PADDLE_ENFORCE_EQ(block_table.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects block_table is contiguous")); + PADDLE_ENFORCE_EQ( + seq_lens.dtype(), + paddle::DataType::INT32, + common::errors::InvalidArgument("seq_lens dtype must be int32")); + PADDLE_ENFORCE_EQ(seq_lens.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects seq_lens is contiguous")); + // check dim and shape + // k_cache: [num_blocks, kv_num_heads, block_size, head_dim] + // v_cache: [num_blocks, kv_num_heads, block_size, head_dim] + // block_table: [num_seqs, max_num_blocks_per_seq] + // seq_lens: [num_seqs] + // q and out: + // if merged_qkv = false: + // q:[num_seqs, hidden_size] + // out:[num_seqs, hidden_size] + // if merged_qkv = true: + // q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim] + // out: [num_seqs, hidden_size] - const auto& q_dims = q.dims(); - PADDLE_ENFORCE_EQ(q_dims.size(), - 2, - common::errors::InvalidArgument( - "paged_attn receive query dims is " - "[num_seqs, (num_heads+2*num_kv_heads)*head_dim]")); - PADDLE_ENFORCE_EQ(out.dims().size(), - 2, - common::errors::InvalidArgument( - "paged_attn receive out dims is " - "[num_seqs, hidden_size]")); + const auto& q_dims = q.dims(); + PADDLE_ENFORCE_EQ(q_dims.size(), + 2, + common::errors::InvalidArgument( + "paged_attn receive query dims is " + "[num_seqs, (num_heads+2*num_kv_heads)*head_dim]")); + PADDLE_ENFORCE_EQ( + out.dims().size(), + 2, + common::errors::InvalidArgument("paged_attn receive out dims is " + "[num_seqs, hidden_size]")); - const auto& kv_cache_dims = k_cache.dims(); - PADDLE_ENFORCE_EQ(kv_cache_dims.size(), - 4, - common::errors::InvalidArgument( - "paged_attn receive kv cache dims is " - "[num_blocks, kv_num_heads, block_size, head_dim]")); + const auto& kv_cache_dims = k_cache.dims(); + PADDLE_ENFORCE_EQ(kv_cache_dims.size(), + 4, + common::errors::InvalidArgument( + "paged_attn receive kv cache dims is " + "[num_blocks, kv_num_heads, block_size, head_dim]")); - const auto& block_table_dims = block_table.dims(); - PADDLE_ENFORCE_EQ(block_table_dims.size(), - 2, - common::errors::InvalidArgument( - "paged_attn receive block_table dims is " - "[num_seqs, max_num_blocks_per_seq]")); + const auto& block_table_dims = block_table.dims(); + PADDLE_ENFORCE_EQ( + block_table_dims.size(), + 2, + common::errors::InvalidArgument("paged_attn receive block_table dims is " + "[num_seqs, max_num_blocks_per_seq]")); - const auto& seq_lens_dims = seq_lens.dims(); - PADDLE_ENFORCE_EQ(seq_lens_dims.size(), - 1, - common::errors::InvalidArgument( - "paged_attn receive seq_lens dims is [num_seqs]")); + const auto& seq_lens_dims = seq_lens.dims(); + PADDLE_ENFORCE_EQ(seq_lens_dims.size(), + 1, + common::errors::InvalidArgument( + "paged_attn receive seq_lens dims is [num_seqs]")); - int num_seqs = q_dims[0]; - int max_num_blocks_per_seq = block_table_dims[1]; - int q_stride = q.strides()[0]; - int num_blocks = kv_cache_dims[0]; + int num_seqs = q_dims[0]; + int max_num_blocks_per_seq = block_table_dims[1]; + int q_stride = q.strides()[0]; + int num_blocks = kv_cache_dims[0]; - PADDLE_ENFORCE_EQ(kv_cache_dims[1], - num_kv_heads, - common::errors::InvalidArgument( - "kv_cache_dims[1] must be equal to num_kv_head")); - PADDLE_ENFORCE_EQ(kv_cache_dims[2], - block_size, - common::errors::InvalidArgument( - "kv_cache_dims[2] must be equal to block_size")); - PADDLE_ENFORCE_EQ(kv_cache_dims[3], - head_dim, - common::errors::InvalidArgument( - "kv_cache_dims[3] must be equal to head_dim")); - PADDLE_ENFORCE_EQ(block_table_dims[0], - num_seqs, - common::errors::InvalidArgument( - "block_table_dims[0] must be equal to num_seqs")); - PADDLE_ENFORCE_EQ(seq_lens_dims[0], - num_seqs, - common::errors::InvalidArgument( - "seq_lens_dims[0] must be equal to num_seqs")); + PADDLE_ENFORCE_EQ(kv_cache_dims[1], + num_kv_heads, + common::errors::InvalidArgument( + "kv_cache_dims[1] must be equal to num_kv_head")); + PADDLE_ENFORCE_EQ(kv_cache_dims[2], + block_size, + common::errors::InvalidArgument( + "kv_cache_dims[2] must be equal to block_size")); + PADDLE_ENFORCE_EQ(kv_cache_dims[3], + head_dim, + common::errors::InvalidArgument( + "kv_cache_dims[3] must be equal to head_dim")); + PADDLE_ENFORCE_EQ(block_table_dims[0], + num_seqs, + common::errors::InvalidArgument( + "block_table_dims[0] must be equal to num_seqs")); + PADDLE_ENFORCE_EQ(seq_lens_dims[0], + num_seqs, + common::errors::InvalidArgument( + "seq_lens_dims[0] must be equal to num_seqs")); - int kv_block_stride = k_cache.strides()[0]; - int kv_head_stride = k_cache.strides()[1]; - const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data() : nullptr; - const void *key_ptr = k ? k.get().data() : nullptr; - const void *value_ptr = v ? v.get().data() : nullptr; - const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data() : nullptr; - const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data() : nullptr; + int kv_block_stride = k_cache.strides()[0]; + int kv_head_stride = k_cache.strides()[1]; + const float* alibi_slopes_ptr = + alibi_slopes ? alibi_slopes.get().data() : nullptr; + const void* key_ptr = k ? k.get().data() : nullptr; + const void* value_ptr = v ? v.get().data() : nullptr; + const float* rope_sin_ptr = + merged_qkv ? rope_sin.get().data() : nullptr; + const float* rope_cos_ptr = + merged_qkv ? rope_cos.get().data() : nullptr; - cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle(); + cuinferHandle_t cuinfer_handle = + iluvatar::getContextInstance()->getIxInferHandle(); - size_t workspace_size = 0; - CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs, - num_heads, - num_kv_heads, - head_dim, - block_size, - max_context_len, - &workspace_size)); - auto* allocator = paddle::GetAllocator(q.place()); - phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size); - void* workspace_ptr = tmp_workspace->ptr(); + size_t workspace_size = 0; + CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs, + num_heads, + num_kv_heads, + head_dim, + block_size, + max_context_len, + &workspace_size)); + auto* allocator = paddle::GetAllocator(q.place()); + phi::Allocator::AllocationPtr tmp_workspace = + allocator->Allocate(workspace_size); + void* workspace_ptr = tmp_workspace->ptr(); - PageAttentionWithKVCacheArguments args{ - static_cast(scale), 1.0, 1.0, static_cast(softcap), window_left, window_right, - causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, - workspace_ptr, merged_qkv, rope_sin_ptr, rope_cos_ptr}; - CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle, - out.data(), - data_type, - q.data(), - data_type, - num_seqs, - num_heads, - num_kv_heads, - head_dim, - q_stride, - kv_block_stride, - kv_head_stride, - k_cache.data(), - data_type, - v_cache.data(), - data_type, - block_size, - max_num_blocks_per_seq, - max_context_len, - block_table.data(), - seq_lens.data(), - args)); + PageAttentionWithKVCacheArguments args{static_cast(scale), + 1.0, + 1.0, + static_cast(softcap), + window_left, + window_right, + causal, + use_sqrt_alibi, + enable_cuda_graph, + false, + alibi_slopes_ptr, + key_ptr, + value_ptr, + workspace_ptr, + merged_qkv, + rope_sin_ptr, + rope_cos_ptr}; + CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle, + out.data(), + data_type, + q.data(), + data_type, + num_seqs, + num_heads, + num_kv_heads, + head_dim, + q_stride, + kv_block_stride, + kv_head_stride, + k_cache.data(), + data_type, + v_cache.data(), + data_type, + block_size, + max_num_blocks_per_seq, + max_context_len, + block_table.data(), + seq_lens.data(), + args)); } -std::vector PagedAttn(const paddle::Tensor& q, - const paddle::Tensor& k_cache, - const paddle::Tensor& v_cache, - const paddle::Tensor& block_table, - const paddle::Tensor& seq_lens, - const paddle::optional &alibi_slopes, - const paddle::optional &k, - const paddle::optional &v, - const paddle::optional &rope_sin, - const paddle::optional &rope_cos, - int num_heads, - int head_dim, - int num_kv_heads, - float scale, - int block_size, - int max_context_len, - bool causal, - int window_left, - int window_right, - float softcap, - bool enable_cuda_graph, - bool use_sqrt_alibi, - bool merged_qkv) { +std::vector PagedAttn( + const paddle::Tensor& q, + const paddle::Tensor& k_cache, + const paddle::Tensor& v_cache, + const paddle::Tensor& block_table, + const paddle::Tensor& seq_lens, + const paddle::optional& alibi_slopes, + const paddle::optional& k, + const paddle::optional& v, + const paddle::optional& rope_sin, + const paddle::optional& rope_cos, + int num_heads, + int head_dim, + int num_kv_heads, + float scale, + int block_size, + int max_context_len, + bool causal, + int window_left, + int window_right, + float softcap, + bool enable_cuda_graph, + bool use_sqrt_alibi, + bool merged_qkv) { + const auto dtype = q.dtype(); + auto out = + paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place()); - const auto dtype = q.dtype(); - auto out = paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place()); - - switch (dtype) { - case paddle::DataType::BFLOAT16: - PagedAttnKernel(q, - k_cache, - v_cache, - block_table, - seq_lens, - alibi_slopes, - k, - v, - rope_sin, - rope_cos, - num_heads, - head_dim, - num_kv_heads, - scale, - block_size, - max_context_len, - causal, - window_left, - window_right, - softcap, - enable_cuda_graph, - use_sqrt_alibi, - merged_qkv, - out); - break; - case paddle::DataType::FLOAT16: - PagedAttnKernel(q, - k_cache, - v_cache, - block_table, - seq_lens, - alibi_slopes, - k, - v, - rope_sin, - rope_cos, - num_heads, - head_dim, - num_kv_heads, - scale, - block_size, - max_context_len, - causal, - window_left, - window_right, - softcap, - enable_cuda_graph, - use_sqrt_alibi, - merged_qkv, - out); - break; - default: - PD_THROW("Unsupported data type for Paged attn"); - } - return {out}; + switch (dtype) { + case paddle::DataType::BFLOAT16: + PagedAttnKernel(q, + k_cache, + v_cache, + block_table, + seq_lens, + alibi_slopes, + k, + v, + rope_sin, + rope_cos, + num_heads, + head_dim, + num_kv_heads, + scale, + block_size, + max_context_len, + causal, + window_left, + window_right, + softcap, + enable_cuda_graph, + use_sqrt_alibi, + merged_qkv, + out); + break; + case paddle::DataType::FLOAT16: + PagedAttnKernel(q, + k_cache, + v_cache, + block_table, + seq_lens, + alibi_slopes, + k, + v, + rope_sin, + rope_cos, + num_heads, + head_dim, + num_kv_heads, + scale, + block_size, + max_context_len, + causal, + window_left, + window_right, + softcap, + enable_cuda_graph, + use_sqrt_alibi, + merged_qkv, + out); + break; + default: + PD_THROW("Unsupported data type for Paged attn"); + } + return {out}; } -std::vector> PagedAttnInferShape(const std::vector& q_shape, - const std::vector& k_cache_shape, - const std::vector& v_cache_shape, - const std::vector& block_table_shape, - const std::vector& seq_lens_shape, - const std::vector& alibi_slopes_shape, - const std::vector& k_shape, - const std::vector& v_shape, - const std::vector& rope_sin_shape, - const std::vector& rope_cos_shape, - int num_heads, - int head_dim, - int num_kv_heads, - float scale, - int block_size, - int max_context_len, - bool causal, - int window_left, - int window_right, - float softcap, - bool enable_cuda_graph, - bool use_sqrt_alibi, - bool merged_qkv) { - if (merged_qkv) { - return {{q_shape[0], num_heads * head_dim}}; - } else { - return {q_shape}; - } +std::vector> PagedAttnInferShape( + const std::vector& q_shape, + const std::vector& k_cache_shape, + const std::vector& v_cache_shape, + const std::vector& block_table_shape, + const std::vector& seq_lens_shape, + const std::vector& alibi_slopes_shape, + const std::vector& k_shape, + const std::vector& v_shape, + const std::vector& rope_sin_shape, + const std::vector& rope_cos_shape, + int num_heads, + int head_dim, + int num_kv_heads, + float scale, + int block_size, + int max_context_len, + bool causal, + int window_left, + int window_right, + float softcap, + bool enable_cuda_graph, + bool use_sqrt_alibi, + bool merged_qkv) { + if (merged_qkv) { + return {{q_shape[0], num_heads * head_dim}}; + } else { + return {q_shape}; + } } -std::vector PagedAttnInferDtype(const paddle::DataType& q_dtype) { - return {q_dtype}; +std::vector PagedAttnInferDtype( + const paddle::DataType& q_dtype) { + return {q_dtype}; } - PD_BUILD_STATIC_OP(paged_attn) - .Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", - paddle::Optional("alibi_slopes"), paddle::Optional("k"), - paddle::Optional("v"), paddle::Optional("rope_sin"), + .Inputs({"q", + "k_cache", + "v_cache", + "block_table", + "seq_lens", + paddle::Optional("alibi_slopes"), + paddle::Optional("k"), + paddle::Optional("v"), + paddle::Optional("rope_sin"), paddle::Optional("rope_cos")}) .Outputs({"out"}) .Attrs({"num_heads:int", @@ -341,11 +368,11 @@ PD_BUILD_STATIC_OP(paged_attn) "scale:float", "block_size:int", "max_context_len:int", - "causal:bool", + "causal:bool", "window_left:int", "window_right:int", "softcap:float", - "enable_cuda_graph:bool", + "enable_cuda_graph:bool", "use_sqrt_alibi:bool", "merged_qkv:bool"}) .SetKernelFn(PD_KERNEL(PagedAttn)) diff --git a/custom_ops/iluvatar_ops/prefill_fused_attn.cu b/custom_ops/iluvatar_ops/prefill_fused_attn.cu index 64251ba59..fe8449c40 100644 --- a/custom_ops/iluvatar_ops/prefill_fused_attn.cu +++ b/custom_ops/iluvatar_ops/prefill_fused_attn.cu @@ -16,352 +16,374 @@ #include "iluvatar_context.h" template -void PrefillFusedPagedAttnKernel(const paddle::Tensor& qkv, - paddle::Tensor& k_cache, - paddle::Tensor& v_cache, - const paddle::Tensor& block_table, - const paddle::Tensor& cu_seqlens_qkv, - const paddle::optional &rope_sin, - const paddle::optional &rope_cos, - int num_heads, - int head_dim, - int num_kv_heads, - int block_size, - int max_seq_len, - float scale, - bool causal, - bool q_rope, - bool k_rope, - bool v_rope, - paddle::Tensor& out) { +void PrefillFusedPagedAttnKernel( + const paddle::Tensor& qkv, + paddle::Tensor& k_cache, + paddle::Tensor& v_cache, + const paddle::Tensor& block_table, + const paddle::Tensor& cu_seqlens_qkv, + const paddle::optional& rope_sin, + const paddle::optional& rope_cos, + int num_heads, + int head_dim, + int num_kv_heads, + int block_size, + int max_seq_len, + float scale, + bool causal, + bool q_rope, + bool k_rope, + bool v_rope, + paddle::Tensor& out) { + // check dtype and contiguous + const auto& dtype = qkv.dtype(); + cuinferDataType_t data_type; + if (dtype == paddle::DataType::FLOAT16) { + data_type = CUINFER_DATA_HALF; - // check dtype and contiguous - const auto& dtype = qkv.dtype(); - cuinferDataType_t data_type; - if (dtype == paddle::DataType::FLOAT16) { - data_type = CUINFER_DATA_HALF; + } else if (dtype == paddle::DataType::BFLOAT16) { + data_type = CUINFER_DATA_BFLOAT16; + } else { + common::errors::InvalidArgument( + "paged_attention support half and bfloat16 now"); + } - } else if (dtype == paddle::DataType::BFLOAT16) { - data_type = CUINFER_DATA_BFLOAT16; - } else { - common::errors::InvalidArgument("paged_attention support half and bfloat16 now"); - } + PADDLE_ENFORCE_EQ(k_cache.dtype(), + dtype, + common::errors::InvalidArgument( + "k_cache dtype must be the same as query dtype")); + PADDLE_ENFORCE_EQ(k_cache.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects k_cache is contiguous")); + PADDLE_ENFORCE_EQ( + block_table.dtype(), + paddle::DataType::INT32, + common::errors::InvalidArgument("block_table dtype must be int32")); + PADDLE_ENFORCE_EQ(block_table.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects block_table is contiguous")); + PADDLE_ENFORCE_EQ( + cu_seqlens_qkv.dtype(), + paddle::DataType::INT32, + common::errors::InvalidArgument("cu_seqlens_qkv dtype must be int32")); + PADDLE_ENFORCE_EQ( + cu_seqlens_qkv.is_contiguous(), + true, + common::errors::InvalidArgument( + "paged_attention expects cu_seqlens_qkv is contiguous")); + // check dim and shape + // k_cache: [num_blocks, kv_num_heads, block_size, head_dim] + // v_cache: [num_blocks, kv_num_heads, block_size, head_dim] + // block_table: [batch_size, max_num_blocks_per_seq] + // seq_lens: [batch_size] + // qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim] + // out: [num_tokens, hidden_size] - PADDLE_ENFORCE_EQ(k_cache.dtype(), - dtype, - common::errors::InvalidArgument( - "k_cache dtype must be the same as query dtype")); - PADDLE_ENFORCE_EQ(k_cache.is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects k_cache is contiguous")); - PADDLE_ENFORCE_EQ(block_table.dtype(), - paddle::DataType::INT32, - common::errors::InvalidArgument( - "block_table dtype must be int32")); - PADDLE_ENFORCE_EQ(block_table.is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects block_table is contiguous")); - PADDLE_ENFORCE_EQ(cu_seqlens_qkv.dtype(), - paddle::DataType::INT32, - common::errors::InvalidArgument( - "cu_seqlens_qkv dtype must be int32")); - PADDLE_ENFORCE_EQ(cu_seqlens_qkv.is_contiguous(), - true, - common::errors::InvalidArgument( - "paged_attention expects cu_seqlens_qkv is contiguous")); - // check dim and shape - // k_cache: [num_blocks, kv_num_heads, block_size, head_dim] - // v_cache: [num_blocks, kv_num_heads, block_size, head_dim] - // block_table: [batch_size, max_num_blocks_per_seq] - // seq_lens: [batch_size] - // qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim] - // out: [num_tokens, hidden_size] + const auto& qkv_dims = qkv.dims(); + PADDLE_ENFORCE_EQ(qkv_dims.size(), + 2, + common::errors::InvalidArgument( + "paged_attn receive query dims is " + "[num_tokens, (num_heads+2*num_kv_heads)*head_dim]")); + PADDLE_ENFORCE_EQ( + out.dims().size(), + 2, + common::errors::InvalidArgument("paged_attn receive out dims is " + "[num_tokens, hidden_size]")); - const auto& qkv_dims = qkv.dims(); - PADDLE_ENFORCE_EQ(qkv_dims.size(), - 2, - common::errors::InvalidArgument( - "paged_attn receive query dims is " - "[num_tokens, (num_heads+2*num_kv_heads)*head_dim]")); - PADDLE_ENFORCE_EQ(out.dims().size(), - 2, - common::errors::InvalidArgument( - "paged_attn receive out dims is " - "[num_tokens, hidden_size]")); + const auto& kv_cache_dims = k_cache.dims(); + PADDLE_ENFORCE_EQ(kv_cache_dims.size(), + 4, + common::errors::InvalidArgument( + "paged_attn receive kv cache dims is " + "[num_blocks, kv_num_heads, block_size, head_dim]")); - const auto& kv_cache_dims = k_cache.dims(); - PADDLE_ENFORCE_EQ(kv_cache_dims.size(), - 4, - common::errors::InvalidArgument( - "paged_attn receive kv cache dims is " - "[num_blocks, kv_num_heads, block_size, head_dim]")); + const auto& block_table_dims = block_table.dims(); + PADDLE_ENFORCE_EQ( + block_table_dims.size(), + 2, + common::errors::InvalidArgument("paged_attn receive block_table dims is " + "[batch_size, max_num_blocks_per_seq]")); - const auto& block_table_dims = block_table.dims(); - PADDLE_ENFORCE_EQ(block_table_dims.size(), - 2, - common::errors::InvalidArgument( - "paged_attn receive block_table dims is " - "[batch_size, max_num_blocks_per_seq]")); + const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims(); + PADDLE_ENFORCE_EQ( + cu_seqlens_qkv_dims.size(), + 1, + common::errors::InvalidArgument( + "paged_attn receive cu_seqlens_qkv dims is [batch_size]")); - const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims(); - PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims.size(), - 1, - common::errors::InvalidArgument( - "paged_attn receive cu_seqlens_qkv dims is [batch_size]")); + int batch_size = block_table_dims[0]; + int num_tokens = qkv_dims[0]; + int num_total_heads = num_heads + 2 * num_kv_heads; + int qkv_stride = qkv.strides()[0]; + int num_blocks = kv_cache_dims[0]; - int batch_size = block_table_dims[0]; - int num_tokens = qkv_dims[0]; - int num_total_heads = num_heads + 2 * num_kv_heads; - int qkv_stride = qkv.strides()[0]; - int num_blocks = kv_cache_dims[0]; + PADDLE_ENFORCE_EQ(kv_cache_dims[1], + num_kv_heads, + common::errors::InvalidArgument( + "kv_cache_dims[1] must be equal to num_kv_head")); + PADDLE_ENFORCE_EQ(kv_cache_dims[2], + block_size, + common::errors::InvalidArgument( + "kv_cache_dims[2] must be equal to block_size")); + PADDLE_ENFORCE_EQ(kv_cache_dims[3], + head_dim, + common::errors::InvalidArgument( + "kv_cache_dims[3] must be equal to head_dim")); + PADDLE_ENFORCE_EQ( + cu_seqlens_qkv_dims[0], + batch_size + 1, + common::errors::InvalidArgument( + "cu_seqlens_qkv_dims[0] must be equal to batch_size + 1")); - PADDLE_ENFORCE_EQ(kv_cache_dims[1], - num_kv_heads, - common::errors::InvalidArgument( - "kv_cache_dims[1] must be equal to num_kv_head")); - PADDLE_ENFORCE_EQ(kv_cache_dims[2], - block_size, - common::errors::InvalidArgument( - "kv_cache_dims[2] must be equal to block_size")); - PADDLE_ENFORCE_EQ(kv_cache_dims[3], - head_dim, - common::errors::InvalidArgument( - "kv_cache_dims[3] must be equal to head_dim")); - PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims[0], - batch_size + 1, - common::errors::InvalidArgument( - "cu_seqlens_qkv_dims[0] must be equal to batch_size + 1")); + int block_table_stride = block_table.strides()[0]; + const float* rope_sin_ptr = rope_sin ? rope_sin.get().data() : nullptr; + const float* rope_cos_ptr = rope_cos ? rope_cos.get().data() : nullptr; - int block_table_stride = block_table.strides()[0]; - const float *rope_sin_ptr = rope_sin ? rope_sin.get().data() : nullptr; - const float *rope_cos_ptr = rope_cos ? rope_cos.get().data() : nullptr; + cuinferHandle_t cuinfer_handle = + iluvatar::getContextInstance()->getIxInferHandle(); - cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle(); + size_t workspace_size = 0; + CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens, + num_heads, + num_kv_heads, + head_dim, + q_rope, + k_rope, + v_rope, + data_type, + data_type, + data_type, + &workspace_size)); + auto* allocator = paddle::GetAllocator(qkv.place()); + phi::Allocator::AllocationPtr tmp_workspace = + allocator->Allocate(workspace_size); + void* workspace_ptr = tmp_workspace->ptr(); - size_t workspace_size = 0; - CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens, - num_heads, - num_kv_heads, - head_dim, - q_rope, - k_rope, - v_rope, - data_type, - data_type, - data_type, - &workspace_size)); - auto* allocator = paddle::GetAllocator(qkv.place()); - phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size); - void* workspace_ptr = tmp_workspace->ptr(); - - cuinferTensorDescriptor_t qkv_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t qkv_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( qkv_desc, data_type, 3, std::vector({num_tokens, num_total_heads, head_dim}).data(), std::vector({num_total_heads * head_dim, head_dim, 1}).data())); - cuinferTensorDescriptor_t qkv_seqlens_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( - qkv_seqlens_desc, - CUINFER_DATA_INT32, - 1, - std::vector({batch_size + 1}).data(), - std::vector({1}).data())); + cuinferTensorDescriptor_t qkv_seqlens_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc)); + CUINFER_CHECK( + cuinferSetTensorNdDescriptor(qkv_seqlens_desc, + CUINFER_DATA_INT32, + 1, + std::vector({batch_size + 1}).data(), + std::vector({1}).data())); - cuinferTensorDescriptor_t block_table_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t block_table_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( block_table_desc, CUINFER_DATA_INT32, 2, std::vector({batch_size, block_table_stride}).data(), std::vector({block_table_stride, 1}).data())); - cuinferTensorDescriptor_t o_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t o_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( o_desc, data_type, 3, std::vector({num_tokens, num_heads, head_dim}).data(), std::vector({num_heads * head_dim, head_dim, 1}).data())); - cuinferTensorDescriptor_t k_cache_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t k_cache_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( k_cache_desc, data_type, 4, std::vector({num_blocks, num_kv_heads, block_size, head_dim}).data(), - std::vector({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data())); + std::vector({num_kv_heads * block_size * head_dim, + block_size * head_dim, + head_dim, + 1}) + .data())); - cuinferTensorDescriptor_t v_cache_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t v_cache_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( v_cache_desc, data_type, 4, std::vector({num_blocks, num_kv_heads, block_size, head_dim}).data(), - std::vector({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data())); + std::vector({num_kv_heads * block_size * head_dim, + block_size * head_dim, + head_dim, + 1}) + .data())); - cuinferTensorDescriptor_t cos_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t cos_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( cos_desc, CUINFER_DATA_FLOAT, 2, std::vector({max_seq_len, head_dim}).data(), std::vector({head_dim, 1}).data())); - cuinferTensorDescriptor_t sin_desc; - CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc)); - CUINFER_CHECK(cuinferSetTensorNdDescriptor( + cuinferTensorDescriptor_t sin_desc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc)); + CUINFER_CHECK(cuinferSetTensorNdDescriptor( sin_desc, CUINFER_DATA_FLOAT, 2, std::vector({max_seq_len, head_dim}).data(), std::vector({head_dim, 1}).data())); - CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle, - qkv_desc, - qkv.data(), - qkv_seqlens_desc, - cu_seqlens_qkv.data(), - block_table_desc, - block_table.data(), - o_desc, - out.data(), - k_cache_desc, - k_cache.data(), - v_cache_desc, - v_cache.data(), - workspace_ptr, - workspace_size, - cos_desc, - rope_cos_ptr, - sin_desc, - rope_sin_ptr, - batch_size, - num_heads, - num_kv_heads, - head_dim, - causal, - scale, - q_rope, - k_rope, - v_rope)); + CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle, + qkv_desc, + qkv.data(), + qkv_seqlens_desc, + cu_seqlens_qkv.data(), + block_table_desc, + block_table.data(), + o_desc, + out.data(), + k_cache_desc, + k_cache.data(), + v_cache_desc, + v_cache.data(), + workspace_ptr, + workspace_size, + cos_desc, + rope_cos_ptr, + sin_desc, + rope_sin_ptr, + batch_size, + num_heads, + num_kv_heads, + head_dim, + causal, + scale, + q_rope, + k_rope, + v_rope)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc)); - CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc)); + CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc)); } -std::vector PrefillFusedPagedAttn(const paddle::Tensor& qkv, - paddle::Tensor& k_cache, - paddle::Tensor& v_cache, - const paddle::Tensor& block_table, - const paddle::Tensor& cu_seqlens_qkv, - const paddle::optional &rope_sin, - const paddle::optional &rope_cos, - int num_heads, - int head_dim, - int num_kv_heads, - int block_size, - int max_seq_len, - float scale, - bool causal, - bool q_rope, - bool k_rope, - bool v_rope) { +std::vector PrefillFusedPagedAttn( + const paddle::Tensor& qkv, + paddle::Tensor& k_cache, + paddle::Tensor& v_cache, + const paddle::Tensor& block_table, + const paddle::Tensor& cu_seqlens_qkv, + const paddle::optional& rope_sin, + const paddle::optional& rope_cos, + int num_heads, + int head_dim, + int num_kv_heads, + int block_size, + int max_seq_len, + float scale, + bool causal, + bool q_rope, + bool k_rope, + bool v_rope) { + const auto dtype = qkv.dtype(); + auto out = + paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place()); - const auto dtype = qkv.dtype(); - auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place()); - - switch (dtype) { - case paddle::DataType::BFLOAT16: - PrefillFusedPagedAttnKernel(qkv, - k_cache, - v_cache, - block_table, - cu_seqlens_qkv, - rope_sin, - rope_cos, - num_heads, - head_dim, - num_kv_heads, - block_size, - max_seq_len, - scale, - causal, - q_rope, - k_rope, - v_rope, - out); - break; - case paddle::DataType::FLOAT16: - PrefillFusedPagedAttnKernel(qkv, - k_cache, - v_cache, - block_table, - cu_seqlens_qkv, - rope_sin, - rope_cos, - num_heads, - head_dim, - num_kv_heads, - block_size, - max_seq_len, - scale, - causal, - q_rope, - k_rope, - v_rope, - out); - break; - default: - PD_THROW("Unsupported data type for Paged attn"); - } - return {out}; + switch (dtype) { + case paddle::DataType::BFLOAT16: + PrefillFusedPagedAttnKernel(qkv, + k_cache, + v_cache, + block_table, + cu_seqlens_qkv, + rope_sin, + rope_cos, + num_heads, + head_dim, + num_kv_heads, + block_size, + max_seq_len, + scale, + causal, + q_rope, + k_rope, + v_rope, + out); + break; + case paddle::DataType::FLOAT16: + PrefillFusedPagedAttnKernel(qkv, + k_cache, + v_cache, + block_table, + cu_seqlens_qkv, + rope_sin, + rope_cos, + num_heads, + head_dim, + num_kv_heads, + block_size, + max_seq_len, + scale, + causal, + q_rope, + k_rope, + v_rope, + out); + break; + default: + PD_THROW("Unsupported data type for Paged attn"); + } + return {out}; } -std::vector> PrefillFusedPagedAttnInferShape(const std::vector& qkv_shape, - const std::vector& k_cache_shape, - const std::vector& v_cache_shape, - const std::vector& block_table_shape, - const std::vector& cu_seqlens_qkv_shape, - const std::vector& rope_sin_shape, - const std::vector& rope_cos_shape, - int num_heads, - int head_dim, - int num_kv_heads, - int block_size, - int max_seq_len, - float scale, - bool causal, - bool q_rope, - bool k_rope, - bool v_rope) { - return {{qkv_shape[0], num_heads * head_dim}}; +std::vector> PrefillFusedPagedAttnInferShape( + const std::vector& qkv_shape, + const std::vector& k_cache_shape, + const std::vector& v_cache_shape, + const std::vector& block_table_shape, + const std::vector& cu_seqlens_qkv_shape, + const std::vector& rope_sin_shape, + const std::vector& rope_cos_shape, + int num_heads, + int head_dim, + int num_kv_heads, + int block_size, + int max_seq_len, + float scale, + bool causal, + bool q_rope, + bool k_rope, + bool v_rope) { + return {{qkv_shape[0], num_heads * head_dim}}; } -std::vector PrefillFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) { - return {qkv_dtype}; +std::vector PrefillFusedPagedAttnInferDtype( + const paddle::DataType& qkv_dtype) { + return {qkv_dtype}; } PD_BUILD_STATIC_OP(prefill_fused_paged_attn) - .Inputs({"qkv", "k_cache", "v_cache", "block_table", "cu_seqlens_qkv", - paddle::Optional("rope_sin"), paddle::Optional("rope_cos")}) + .Inputs({"qkv", + "k_cache", + "v_cache", + "block_table", + "cu_seqlens_qkv", + paddle::Optional("rope_sin"), + paddle::Optional("rope_cos")}) .Outputs({"out"}) .Attrs({"num_heads:int", "head_dim:int", @@ -369,8 +391,8 @@ PD_BUILD_STATIC_OP(prefill_fused_paged_attn) "block_size:int", "max_seq_len:int", "scale:float", - "causal:bool", - "q_rope:bool", + "causal:bool", + "q_rope:bool", "k_rope:bool", "v_rope:bool"}) .SetKernelFn(PD_KERNEL(PrefillFusedPagedAttn)) diff --git a/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc b/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc index d64f57d11..a0748266a 100644 --- a/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc +++ b/custom_ops/iluvatar_ops/runtime/iluvatar_context.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "iluvatar_context.h" #include diff --git a/custom_ops/iluvatar_ops/runtime/iluvatar_context.h b/custom_ops/iluvatar_ops/runtime/iluvatar_context.h index 80c49bcd5..239030cae 100644 --- a/custom_ops/iluvatar_ops/runtime/iluvatar_context.h +++ b/custom_ops/iluvatar_ops/runtime/iluvatar_context.h @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,27 +32,26 @@ #include #define CUINFER_CHECK(func) \ - do { \ - cuinferStatus_t status = (func); \ - if (status != CUINFER_STATUS_SUCCESS) { \ - std::cerr << "Error in file " << __FILE__ << " on line " \ - << __LINE__ << ": " << cuinferGetErrorString(status) \ - << std::endl; \ - throw std::runtime_error("CUINFER_CHECK ERROR"); \ - } \ - } while (0) + do { \ + cuinferStatus_t status = (func); \ + if (status != CUINFER_STATUS_SUCCESS) { \ + std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ \ + << ": " << cuinferGetErrorString(status) << std::endl; \ + throw std::runtime_error("CUINFER_CHECK ERROR"); \ + } \ + } while (0) namespace iluvatar { class IluvatarContext { - public: - IluvatarContext() = default; - ~IluvatarContext(); + public: + IluvatarContext() = default; + ~IluvatarContext(); - cuinferHandle_t getIxInferHandle(); + cuinferHandle_t getIxInferHandle(); - private: - cuinferHandle_t ixinfer_handle_{nullptr}; + private: + cuinferHandle_t ixinfer_handle_{nullptr}; }; IluvatarContext* getContextInstance(); diff --git a/custom_ops/iluvatar_ops/w8a16_group_gemm.cu b/custom_ops/iluvatar_ops/w8a16_group_gemm.cu index a9b61b682..54a350713 100644 --- a/custom_ops/iluvatar_ops/w8a16_group_gemm.cu +++ b/custom_ops/iluvatar_ops/w8a16_group_gemm.cu @@ -20,157 +20,157 @@ std::vector GroupGemm(const paddle::Tensor& x, const paddle::Tensor& weight_scale, const paddle::Tensor& prefix_sum, const int32_t group_size) { - auto dev_ctx = static_cast( - paddle::experimental::DeviceContextPool::Instance().Get(x.place())); - auto stream = static_cast(dev_ctx->stream()); - const auto& x_dims = x.dims(); - const auto& w_dims = weight.dims(); - const auto& ws_dims = weight_scale.dims(); - const auto& prefix_sum_dims = prefix_sum.dims(); - // [m, k] - PD_CHECK(x_dims.size() == 2, "x should be 2D"); - // [n_experts, n, k] - PD_CHECK(w_dims.size() == 3, "weight should be 3D"); - // [n_experts, n] - PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D"); - // [n_experts] - PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D"); - PD_CHECK(group_size == -1); - auto m = x_dims[0]; - auto k = x_dims[1]; - auto n_experts = w_dims[0]; - auto n = w_dims[1]; - PD_CHECK(w_dims[2] == k); - PD_CHECK(ws_dims[0] == n_experts); - PD_CHECK(ws_dims[1] == n); - PD_CHECK(prefix_sum_dims[0] == n_experts); + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(x.place())); + auto stream = static_cast(dev_ctx->stream()); + const auto& x_dims = x.dims(); + const auto& w_dims = weight.dims(); + const auto& ws_dims = weight_scale.dims(); + const auto& prefix_sum_dims = prefix_sum.dims(); + // [m, k] + PD_CHECK(x_dims.size() == 2, "x should be 2D"); + // [n_experts, n, k] + PD_CHECK(w_dims.size() == 3, "weight should be 3D"); + // [n_experts, n] + PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D"); + // [n_experts] + PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D"); + PD_CHECK(group_size == -1); + auto m = x_dims[0]; + auto k = x_dims[1]; + auto n_experts = w_dims[0]; + auto n = w_dims[1]; + PD_CHECK(w_dims[2] == k); + PD_CHECK(ws_dims[0] == n_experts); + PD_CHECK(ws_dims[1] == n); + PD_CHECK(prefix_sum_dims[0] == n_experts); - PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64); - PD_CHECK(prefix_sum.is_cpu()); - PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 || - x.dtype() == paddle::DataType::FLOAT16); - PD_CHECK(weight.dtype() == paddle::DataType::INT8); - PD_CHECK(weight_scale.dtype() == x.dtype()); - PD_CHECK(x.is_contiguous()); - PD_CHECK(weight.is_contiguous()); - PD_CHECK(weight_scale.is_contiguous()); + PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64); + PD_CHECK(prefix_sum.is_cpu()); + PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 || + x.dtype() == paddle::DataType::FLOAT16); + PD_CHECK(weight.dtype() == paddle::DataType::INT8); + PD_CHECK(weight_scale.dtype() == x.dtype()); + PD_CHECK(x.is_contiguous()); + PD_CHECK(weight.is_contiguous()); + PD_CHECK(weight_scale.is_contiguous()); - const int64_t* prefix_sum_ptr = prefix_sum.data(); - auto output = GetEmptyTensor({m, n}, x.dtype(), x.place()); - int16_t* out_data = static_cast(output.data()); - const int16_t* x_data = static_cast(x.data()); - const int8_t* weight_data = weight.data(); - const int16_t* weight_scale_data = - static_cast(weight_scale.data()); + const int64_t* prefix_sum_ptr = prefix_sum.data(); + auto output = GetEmptyTensor({m, n}, x.dtype(), x.place()); + int16_t* out_data = static_cast(output.data()); + const int16_t* x_data = static_cast(x.data()); + const int8_t* weight_data = weight.data(); + const int16_t* weight_scale_data = + static_cast(weight_scale.data()); - cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle(); - cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; - cuinferOperation_t transa = CUINFER_OP_T; - cuinferOperation_t transb = CUINFER_OP_N; - cudaDataType_t a_type = CUDA_R_8I; - cudaDataType_t b_type; - cudaDataType_t c_type; - if (x.dtype() == paddle::DataType::FLOAT16) { - b_type = CUDA_R_16F; - } else if (x.dtype() == paddle::DataType::BFLOAT16) { - b_type = CUDA_R_16BF; - } else { - PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype.")); + cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle(); + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_N; + cudaDataType_t a_type = CUDA_R_8I; + cudaDataType_t b_type; + cudaDataType_t c_type; + if (x.dtype() == paddle::DataType::FLOAT16) { + b_type = CUDA_R_16F; + } else if (x.dtype() == paddle::DataType::BFLOAT16) { + b_type = CUDA_R_16BF; + } else { + PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype.")); + } + c_type = b_type; + cudaDataType_t Atype = a_type; + cudaDataType_t Btype = b_type; + cudaDataType_t Ctype = c_type; + cudaDataType_t computeType = CUDA_R_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + + cuinferQuantGEMMHostParam cust_host_param; + cust_host_param.size = sizeof(cuinferQuantGEMMHostParam); + cust_host_param.persistent = 0; + cust_host_param.groupSize = group_size; + cuinferQuantGEMMDeviceParam cust_device_param; + cust_device_param.bias = nullptr; + cust_device_param.workspace = nullptr; + + int lda = k; + int ldb = k; + int ldc = n; + float beta = 0.f; + float alpha = 1.f; + int batch_count = 1; + size_t pre = 0; + + auto* allocator = paddle::GetAllocator(x.place()); + phi::Allocator::AllocationPtr tmp_workspace; + for (int i = 0; i < n_experts; i++) { + size_t expert_i_end = prefix_sum_ptr[i]; + size_t cur_len = expert_i_end - pre; + pre = expert_i_end; + if (cur_len != 0) { + cust_device_param.scale = weight_scale_data; + + if (k % 64 != 0) { + size_t workspace_size; + CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa, + transb, + n, + cur_len, + k, + Atype, + lda, + lda, + Btype, + ldb, + ldb, + Ctype, + ldc, + ldc, + batch_count, + computeType, + scaleType, + &workspace_size)); + tmp_workspace = allocator->Allocate(workspace_size); + cust_device_param.workspace = tmp_workspace->ptr(); + } else { + cust_device_param.workspace = nullptr; + } + + CUINFER_CHECK(cuinferCustomGemm(handle, + stream, + cuinfer_ptr_mode, + transa, + transb, + n, + cur_len, + k, + &alpha, + weight_data, + Atype, + lda, + lda, + x_data, + Btype, + ldb, + ldb, + &beta, + out_data, + Ctype, + ldc, + ldc, + batch_count, + computeType, + scaleType, + &cust_host_param, + &cust_device_param, + customOption)); } - c_type = b_type; - cudaDataType_t Atype = a_type; - cudaDataType_t Btype = b_type; - cudaDataType_t Ctype = c_type; - cudaDataType_t computeType = CUDA_R_32F; - cudaDataType_t scaleType = CUDA_R_32F; - cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; - - cuinferQuantGEMMHostParam cust_host_param; - cust_host_param.size = sizeof(cuinferQuantGEMMHostParam); - cust_host_param.persistent = 0; - cust_host_param.groupSize = group_size; - cuinferQuantGEMMDeviceParam cust_device_param; - cust_device_param.bias = nullptr; - cust_device_param.workspace = nullptr; - - int lda = k; - int ldb = k; - int ldc = n; - float beta = 0.f; - float alpha = 1.f; - int batch_count = 1; - size_t pre = 0; - - auto* allocator = paddle::GetAllocator(x.place()); - phi::Allocator::AllocationPtr tmp_workspace; - for (int i = 0; i < n_experts; i++) { - size_t expert_i_end = prefix_sum_ptr[i]; - size_t cur_len = expert_i_end - pre; - pre = expert_i_end; - if (cur_len != 0) { - cust_device_param.scale = weight_scale_data; - - if (k % 64 != 0) { - size_t workspace_size; - CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa, - transb, - n, - cur_len, - k, - Atype, - lda, - lda, - Btype, - ldb, - ldb, - Ctype, - ldc, - ldc, - batch_count, - computeType, - scaleType, - &workspace_size)); - tmp_workspace = allocator->Allocate(workspace_size); - cust_device_param.workspace = tmp_workspace->ptr(); - } else { - cust_device_param.workspace = nullptr; - } - - CUINFER_CHECK(cuinferCustomGemm(handle, - stream, - cuinfer_ptr_mode, - transa, - transb, - n, - cur_len, - k, - &alpha, - weight_data, - Atype, - lda, - lda, - x_data, - Btype, - ldb, - ldb, - &beta, - out_data, - Ctype, - ldc, - ldc, - batch_count, - computeType, - scaleType, - &cust_host_param, - &cust_device_param, - customOption)); - } - x_data += cur_len * k; - weight_data += k * n; - weight_scale_data += n; - out_data += cur_len * n; - } - return {output}; + x_data += cur_len * k; + weight_data += k * n; + weight_scale_data += n; + out_data += cur_len * n; + } + return {output}; } std::vector> GroupGemmInferShape( @@ -178,7 +178,7 @@ std::vector> GroupGemmInferShape( const std::vector& weight_shape, const std::vector& weight_scale_shape, const std::vector& prefix_sum_shape) { - return {{x_shape[0], weight_shape[1]}}; + return {{x_shape[0], weight_shape[1]}}; } std::vector GroupGemmInferDtype( const paddle::DataType& input_dtype, @@ -186,7 +186,7 @@ std::vector GroupGemmInferDtype( const paddle::DataType& weight_scale_dtype, const paddle::DataType& prefix_sum_dtype, const int moe_topk) { - return {input_dtype}; + return {input_dtype}; } PD_BUILD_STATIC_OP(w8a16_group_gemm) diff --git a/custom_ops/metax_ops/fused_moe.cu b/custom_ops/metax_ops/fused_moe.cu index c3f2169d4..fbfaa952d 100644 --- a/custom_ops/metax_ops/fused_moe.cu +++ b/custom_ops/metax_ops/fused_moe.cu @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. - #pragma once +#include "fused_moe_op.h" #include "helper.h" #include "mc_fused_moe_helper.h" -#include "fused_moe_op.h" __global__ void compute_total_rows_before_expert_kernel( int* sorted_experts, @@ -43,7 +42,10 @@ void compute_total_rows_before_expert(int* sorted_indices, sorted_indices, total_indices, num_experts, total_rows_before_expert); } -template +template void FusedMoeKernel(const paddle::Tensor& input, const paddle::Tensor& gate_weight, const paddle::Tensor& ffn1_weight, @@ -63,27 +65,26 @@ void FusedMoeKernel(const paddle::Tensor& input, auto* output_data = output->data(); - auto moe_compute = McMoeHelper(quant_method); + auto moe_compute = + McMoeHelper(quant_method); - moe_compute.computeFFN( - &input, - &gate_weight, - &ffn1_weight, - ffn1_scale ? ffn1_scale.get_ptr() : nullptr, - ffn1_bias ? ffn1_bias.get_ptr() : nullptr, - &ffn2_weight, - ffn2_scale ? ffn2_scale.get_ptr() : nullptr, - ffn2_bias ? ffn2_bias.get_ptr() : nullptr, - nullptr, - moe_topk, - group_moe, - norm_topk_prob, - 1.0, // ComputeFFN - "ffn", - output); + moe_compute.computeFFN(&input, + &gate_weight, + &ffn1_weight, + ffn1_scale ? ffn1_scale.get_ptr() : nullptr, + ffn1_bias ? ffn1_bias.get_ptr() : nullptr, + &ffn2_weight, + ffn2_scale ? ffn2_scale.get_ptr() : nullptr, + ffn2_bias ? ffn2_bias.get_ptr() : nullptr, + nullptr, + moe_topk, + group_moe, + norm_topk_prob, + 1.0, // ComputeFFN + "ffn", + output); } - std::vector FusedExpertMoe( const paddle::Tensor& input, const paddle::Tensor& gate_weight, @@ -102,19 +103,22 @@ std::vector FusedExpertMoe( switch (input_type) { case paddle::DataType::BFLOAT16: - FusedMoeKernel(input, - gate_weight, - ffn1_weight, - ffn1_scale, - ffn1_bias, - ffn2_weight, - ffn2_scale, - ffn2_bias, - quant_method, - moe_topk, - group_moe, - norm_topk_prob, - &output); + FusedMoeKernel(input, + gate_weight, + ffn1_weight, + ffn1_scale, + ffn1_bias, + ffn2_weight, + ffn2_scale, + ffn2_bias, + quant_method, + moe_topk, + group_moe, + norm_topk_prob, + &output); break; // case paddle::DataType::FLOAT16: // FusedMoeKernel(input, @@ -161,7 +165,6 @@ std::vector FusedExpertMoeInferDtype( return {input_dtype}; } - PD_BUILD_OP(fused_expert_moe) .Inputs({"input", "gate_weight", diff --git a/custom_ops/metax_ops/fused_moe_imp_op.h b/custom_ops/metax_ops/fused_moe_imp_op.h index 547b4cacc..99aabaf8a 100644 --- a/custom_ops/metax_ops/fused_moe_imp_op.h +++ b/custom_ops/metax_ops/fused_moe_imp_op.h @@ -16,8 +16,8 @@ */ #pragma once -#include #include +#include #include "cub/cub.cuh" static const float HALF_FLT_MAX = 65504.F; diff --git a/custom_ops/metax_ops/fused_moe_op.h b/custom_ops/metax_ops/fused_moe_op.h index b53df12bf..00ed38115 100644 --- a/custom_ops/metax_ops/fused_moe_op.h +++ b/custom_ops/metax_ops/fused_moe_op.h @@ -19,9 +19,9 @@ #include #include -#include "fused_moe_imp_op.h" #include "fused_moe_helper.h" -#include "mctlass/numeric_conversion.h" // BUILD_MARK +#include "fused_moe_imp_op.h" +#include "mctlass/numeric_conversion.h" // BUILD_MARK // Ignore mctlass warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -35,8 +35,8 @@ #define WARP_SIZE 32 struct GpuLaunchConfig { - dim3 block_per_grid; - dim3 thread_per_block; + dim3 block_per_grid; + dim3 thread_per_block; }; inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { @@ -82,7 +82,6 @@ __launch_bounds__(TPB) __global__ cub::Sum sum; float threadData(-FLT_MAX); - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; threadData = max(static_cast(input[idx]), threadData); @@ -603,7 +602,7 @@ void topk_gating_softmax_kernelLauncher(const T* input, } static constexpr int WARPS_PER_TB = 4; - #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ +#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \ case N: { \ topk_gating_softmax_launcher_helper( \ input, output, indices, source_row, num_rows, num_experts, k, stream); \ @@ -646,14 +645,8 @@ void topk_gating_softmax_kernelLauncher(const T* input, const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows); moe_softmax<<>>( input, softmax, num_experts, num_rows); - moe_top_k - <<>>(softmax, - output, - indices, - source_row, - num_experts, - k, - num_rows); + moe_top_k<<>>( + softmax, output, indices, source_row, num_experts, k, num_rows); } } } diff --git a/custom_ops/metax_ops/mc_fused_moe_helper.h b/custom_ops/metax_ops/mc_fused_moe_helper.h index 525b1b97e..7cc4a18e8 100644 --- a/custom_ops/metax_ops/mc_fused_moe_helper.h +++ b/custom_ops/metax_ops/mc_fused_moe_helper.h @@ -1,52 +1,71 @@ +#include "fused_moe_helper.h" #include "mctlass/numeric_conversion.h" #include "mctlassEx/mctlassEx.h" -#include "fused_moe_helper.h" - template -void mc_grouped_gemm_basic_kernel( - const ElementA* ptrA, - mctlassExOrder_t majorA, - const ElementB* ptrB, - mctlassExOrder_t majorB, - const ElementA* ptrScale, - const ElementA* ptrBias, - ElementC* ptrC, - mctlassExOrder_t majorC, - const int *ptrSegInd, - int numExperts, - int m, // expanded_active_expert_rows - int n, // inter_dim - int k, // hidden_size - mcStream_t stream) { +void mc_grouped_gemm_basic_kernel(const ElementA *ptrA, + mctlassExOrder_t majorA, + const ElementB *ptrB, + mctlassExOrder_t majorB, + const ElementA *ptrScale, + const ElementA *ptrBias, + ElementC *ptrC, + mctlassExOrder_t majorC, + const int *ptrSegInd, + int numExperts, + int m, // expanded_active_expert_rows + int n, // inter_dim + int k, // hidden_size + mcStream_t stream) { mctlassExHandle_t handle; mctlassExHandleCreate(&handle); - int* ptrMNumTilesInd; - mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream); + int *ptrMNumTilesInd; + mcMallocAsync((void **)&ptrMNumTilesInd, sizeof(int) * numExperts, stream); mctlassExMatrixLayout_t matLayoutA; mctlassExMatrixLayout_t matLayoutB; mctlassExMatrixLayout_t matLayoutC; // mat A: (m, k) - mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k); - mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, - &majorA, sizeof(mctlassExOrder_t)); - mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, - &numExperts, sizeof(int)); + mctlassExMatrixLayoutCreate( + &matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k); + mctlassExMatrixLayoutSetAttribute( + matLayoutA, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorA, + sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute( + matLayoutA, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, + sizeof(int)); // mat B: (num_experts, n, k) - mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k); - mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, - &majorB, sizeof(mctlassExOrder_t)); - mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, - &numExperts, sizeof(int)); + mctlassExMatrixLayoutCreate( + &matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k); + mctlassExMatrixLayoutSetAttribute( + matLayoutB, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorB, + sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute( + matLayoutB, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, + sizeof(int)); // mat C: (m, n) - mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n); - mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, - &majorC, sizeof(mctlassExOrder_t)); - mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, - &numExperts, sizeof(int)); + mctlassExMatrixLayoutCreate( + &matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n); + mctlassExMatrixLayoutSetAttribute( + matLayoutC, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorC, + sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute( + matLayoutC, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, + sizeof(int)); // bias: (num_experts, n) // scale: (num, n) @@ -55,44 +74,81 @@ void mc_grouped_gemm_basic_kernel( mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16; mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8; mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32; - mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT; + mctlassExEpilogueType epilogue_type = + mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT; if (ptrBias) { epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS; } // set scale - mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER, - &ptrScale, sizeof(ptrScale)); - mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE, - &input_type, sizeof(mctlassExDataType)); + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER, + &ptrScale, + sizeof(ptrScale)); + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE, + &input_type, + sizeof(mctlassExDataType)); // set bias if (ptrBias) { - mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER, - &ptrBias, sizeof(ptrBias)); + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER, + &ptrBias, + sizeof(ptrBias)); } // set coumpute type - mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE, - &compute_type, sizeof(mctlassExDataType)); + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE, + &compute_type, + sizeof(mctlassExDataType)); // set epilogue type - mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE, - &epilogue_type, sizeof(mctlassExEpilogueType)); + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE, + &epilogue_type, + sizeof(mctlassExEpilogueType)); - const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT; + const mctlassExContiguousGroupedGemmAlgo_t algo = + mctlassExContiguousGroupedGemmAlgo_t:: + MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT; mctlassExContiguousGroupedDesc_t contiguous_group_desc; - mctlassExContiguousGroupedDescCreate(&contiguous_group_desc, - ptrSegInd, - nullptr, - ptrMNumTilesInd, - 1); + mctlassExContiguousGroupedDescCreate( + &contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1); int blocksizeM; - mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM); - mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream); + mctlassExContiguousGroupedGemmGetBlocksizeM(handle, + mctlass_desc, + matLayoutA, + matLayoutB, + matLayoutC, + &algo, + &blocksizeM); + mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, + mctlass_desc, + matLayoutA, + matLayoutB, + matLayoutC, + &algo, + contiguous_group_desc, + numExperts, + blocksizeM, + stream); - mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc, - ptrA, matLayoutA, - ptrB, matLayoutB, - ptrC, matLayoutC, + mctlassExContiguousGroupedGemmBasic(handle, + mctlass_desc, + ptrA, + matLayoutA, + ptrB, + matLayoutB, + ptrC, + matLayoutC, contiguous_group_desc, - &algo, nullptr, 0, stream); + &algo, + nullptr, + 0, + stream); mctlassExHandleDestroy(handle); mctlassExMatrixLayoutDestroy(matLayoutA); @@ -103,312 +159,312 @@ void mc_grouped_gemm_basic_kernel( mcFreeAsync(ptrMNumTilesInd, stream); } -template +template class McMoeHelper { - public: - McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {} + public: + McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {} - // -------- getWorkspaceSize -------- // - template - size_t getWorkspaceSize(const int64_t num_rows, - const int64_t hidden_size, - const int64_t inter_size, - const int64_t num_experts, - const int64_t k) { - const size_t buf_size = AlignTo16(k * num_rows * hidden_size); - const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); - const size_t padded_experts = AlignTo16(num_experts); - const size_t num_moe_inputs = AlignTo16(k * num_rows); - // softmax output, permuted_rows and permuted_experts have moved to outside - // of moe kernel, allocate them in Encoder or Decoder before invoking - // FfnLayer forward. - size_t total_ws_bytes = - 5 * num_moe_inputs * - sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ - total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data - total_ws_bytes += - padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_ + // -------- getWorkspaceSize -------- // + template + size_t getWorkspaceSize(const int64_t num_rows, + const int64_t hidden_size, + const int64_t inter_size, + const int64_t num_experts, + const int64_t k) { + const size_t buf_size = AlignTo16(k * num_rows * hidden_size); + const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); + const size_t padded_experts = AlignTo16(num_experts); + const size_t num_moe_inputs = AlignTo16(k * num_rows); + // softmax output, permuted_rows and permuted_experts have moved to outside + // of moe kernel, allocate them in Encoder or Decoder before invoking + // FfnLayer forward. + size_t total_ws_bytes = + 5 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + total_ws_bytes += + padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_ - const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT); - const size_t sorter_ws_size_bytes = - AlignTo16(sorter_.getWorkspaceSize(num_rows)); - sorter_.update_num_experts(num_experts); + const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT); + const size_t sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(num_experts); - int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; - if (sorter_ws_size_bytes > bytes_for_fc1_result) { - int64_t remaining_bytes = - AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); - bytes_for_intermediate_and_sorting += remaining_bytes; - } - - total_ws_bytes += - bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub - // sorting workspace - - int64_t num_softmax_outs = 0; - const bool is_pow_2 = - (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) { - num_softmax_outs = AlignTo16(num_rows * num_experts); - } - - total_ws_bytes += num_softmax_outs * sizeof(float); - - return total_ws_bytes; + int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int64_t remaining_bytes = + AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; } - void computeFFN(const paddle::Tensor *input, - const paddle::Tensor *gate_weight, - const paddle::Tensor *ffn1_weight, - const paddle::Tensor *ffn1_scale, - const paddle::Tensor *ffn1_bias, - const paddle::Tensor *ffn2_weight, - const paddle::Tensor *ffn2_scale, - const paddle::Tensor *ffn2_bias, - const paddle::Tensor *moe_token_type_ids, - const int moe_topk, - const bool group_moe, - const bool norm_topk_prob, - const float routed_scaling_factor, - const std::string moe_type, - paddle::Tensor *output) { - auto *input_activations = input->data(); - auto *gating_weights = gate_weight->data(); - const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data() : nullptr; - const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data() : nullptr; + total_ws_bytes += + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace - auto *output_ = output->data(); - auto stream = input->stream(); - auto place = input->place(); - auto input_type = input->dtype(); + int64_t num_softmax_outs = 0; + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = AlignTo16(num_rows * num_experts); + } - auto input_dims = input->dims(); - auto ffn1_dims = ffn1_weight->dims(); - int64_t token_num = 0; - if (input_dims.size() == 3) { - token_num = input_dims[0] * input_dims[1]; - } else { - token_num = input_dims[0]; - } - const int64_t num_rows = token_num; + total_ws_bytes += num_softmax_outs * sizeof(float); - const int64_t hidden_size = ffn1_dims[2]; - int64_t inter_dim = 0; - if (moe_type == "qkv") { - inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4]; - } else { - inter_dim = ffn1_dims[1]; - } + return total_ws_bytes; + } - // if (gemm_method == "weight_only_int4") { - // inter_dim = inter_dim * 2; - // } + void computeFFN(const paddle::Tensor *input, + const paddle::Tensor *gate_weight, + const paddle::Tensor *ffn1_weight, + const paddle::Tensor *ffn1_scale, + const paddle::Tensor *ffn1_bias, + const paddle::Tensor *ffn2_weight, + const paddle::Tensor *ffn2_scale, + const paddle::Tensor *ffn2_bias, + const paddle::Tensor *moe_token_type_ids, + const int moe_topk, + const bool group_moe, + const bool norm_topk_prob, + const float routed_scaling_factor, + const std::string moe_type, + paddle::Tensor *output) { + auto *input_activations = input->data(); + auto *gating_weights = gate_weight->data(); + const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data() : nullptr; + const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data() : nullptr; - const int64_t inter_size = inter_dim; - const int64_t num_experts = ffn1_dims[0]; - const int64_t k = moe_topk; + auto *output_ = output->data(); + auto stream = input->stream(); + auto place = input->place(); + auto input_type = input->dtype(); + auto input_dims = input->dims(); + auto ffn1_dims = ffn1_weight->dims(); + int64_t token_num = 0; + if (input_dims.size() == 3) { + token_num = input_dims[0] * input_dims[1]; + } else { + token_num = input_dims[0]; + } + const int64_t num_rows = token_num; - int64_t bytes = - getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k); + const int64_t hidden_size = ffn1_dims[2]; + int64_t inter_dim = 0; + if (moe_type == "qkv") { + inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4]; + } else { + inter_dim = ffn1_dims[1]; + } - // Pointers - int *expert_for_source_row; - int *source_rows_; - int *permuted_rows_; - int *permuted_experts_; - int *expanded_source_row_to_expanded_dest_row; + // if (gemm_method == "weight_only_int4") { + // inter_dim = inter_dim * 2; + // } - T *permuted_data_; - int32_t *total_rows_before_expert_; - T *fc1_result_; - float *softmax_out_; + const int64_t inter_size = inter_dim; + const int64_t num_experts = ffn1_dims[0]; + const int64_t k = moe_topk; - paddle::Tensor ws_ptr_tensor = - GetEmptyTensor({bytes}, paddle::DataType::INT8, place); - int8_t *ws_ptr = ws_ptr_tensor.data(); + int64_t bytes = + getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k); - const int64_t buf_size = AlignTo16(k * num_rows * hidden_size); - const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size); - const int64_t padded_experts = AlignTo16(num_experts); - const int64_t num_moe_inputs = AlignTo16(k * num_rows); + // Pointers + int *expert_for_source_row; + int *source_rows_; + int *permuted_rows_; + int *permuted_experts_; + int *expanded_source_row_to_expanded_dest_row; - expert_for_source_row = reinterpret_cast(ws_ptr); - source_rows_ = expert_for_source_row + num_moe_inputs; - permuted_rows_ = source_rows_ + num_moe_inputs; - permuted_experts_ = permuted_rows_ + num_moe_inputs; - expanded_source_row_to_expanded_dest_row = - permuted_experts_ + num_moe_inputs; - permuted_data_ = reinterpret_cast( - expanded_source_row_to_expanded_dest_row + num_moe_inputs); - total_rows_before_expert_ = - reinterpret_cast(permuted_data_ + buf_size); - fc1_result_ = - reinterpret_cast(total_rows_before_expert_ + padded_experts); + T *permuted_data_; + int32_t *total_rows_before_expert_; + T *fc1_result_; + float *softmax_out_; - const bool is_pow_2 = - (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); - } else { - softmax_out_ = nullptr; - } + paddle::Tensor ws_ptr_tensor = + GetEmptyTensor({bytes}, paddle::DataType::INT8, place); + int8_t *ws_ptr = ws_ptr_tensor.data(); - paddle::Tensor expert_scales_float_tensor = - GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); - float *expert_scales_float = expert_scales_float_tensor.data(); + const int64_t buf_size = AlignTo16(k * num_rows * hidden_size); + const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size); + const int64_t padded_experts = AlignTo16(num_experts); + const int64_t num_moe_inputs = AlignTo16(k * num_rows); - float *softmax_max_prob = nullptr; - if (group_moe) { - paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor( - {num_rows, moe_topk}, paddle::DataType::FLOAT32, place); - // (TODO: check fill success ?) - paddle::experimental::fill(softmax_max_prob_tensor, 0.f); - softmax_max_prob = softmax_max_prob_tensor.data(); - } + expert_for_source_row = reinterpret_cast(ws_ptr); + source_rows_ = expert_for_source_row + num_moe_inputs; + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expanded_source_row_to_expanded_dest_row = + permuted_experts_ + num_moe_inputs; + permuted_data_ = reinterpret_cast( + expanded_source_row_to_expanded_dest_row + num_moe_inputs); + total_rows_before_expert_ = + reinterpret_cast(permuted_data_ + buf_size); + fc1_result_ = + reinterpret_cast(total_rows_before_expert_ + padded_experts); - paddle::Tensor fc1_out_tensor = - GetEmptyTensor({num_rows * k, inter_size}, input_type, place); - T *fc1_out = fc1_out_tensor.data(); + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } - auto input_cast_tensor = - paddle::experimental::cast(*input, paddle::DataType::FLOAT32); - auto gate_tensor = - paddle::experimental::matmul(input_cast_tensor, *gate_weight); - float *gating_output = gate_tensor.data(); + paddle::Tensor expert_scales_float_tensor = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + float *expert_scales_float = expert_scales_float_tensor.data(); - if (moe_token_type_ids) { - auto *moe_token_type_ids_out = moe_token_type_ids->data(); - moe_token_type_ids_kernelLauncher(gating_output, - moe_token_type_ids_out, - num_rows, - num_experts, - k, - stream); - } + float *softmax_max_prob = nullptr; + if (group_moe) { + paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor( + {num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + // (TODO: check fill success ?) + paddle::experimental::fill(softmax_max_prob_tensor, 0.f); + softmax_max_prob = softmax_max_prob_tensor.data(); + } - topk_gating_softmax_kernelLauncher(gating_output, - expert_scales_float, - softmax_out_, - expert_for_source_row, - source_rows_, - softmax_max_prob, - num_rows, - num_experts, - k, - group_moe, - stream); + paddle::Tensor fc1_out_tensor = + GetEmptyTensor({num_rows * k, inter_size}, input_type, place); + T *fc1_out = fc1_out_tensor.data(); - const int64_t sorter_ws_size_bytes = - AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); + auto input_cast_tensor = + paddle::experimental::cast(*input, paddle::DataType::FLOAT32); + auto gate_tensor = + paddle::experimental::matmul(input_cast_tensor, *gate_weight); + float *gating_output = gate_tensor.data(); - sorter_.run(fc1_result_, - sorter_ws_size_bytes, - expert_for_source_row, - permuted_experts_, - source_rows_, - permuted_rows_, - k * num_rows, - false, - stream); + if (moe_token_type_ids) { + auto *moe_token_type_ids_out = moe_token_type_ids->data(); + moe_token_type_ids_kernelLauncher(gating_output, + moe_token_type_ids_out, + num_rows, + num_experts, + k, + stream); + } - initialize_moe_routing_kernelLauncher( - input_activations, - permuted_data_, - permuted_rows_, - expanded_source_row_to_expanded_dest_row, - num_rows, - num_rows, - hidden_size, - k, - stream); + topk_gating_softmax_kernelLauncher(gating_output, + expert_scales_float, + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + num_experts, + k, + group_moe, + stream); - const int64_t expanded_active_expert_rows = k * num_rows; + const int64_t sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); - compute_total_rows_before_expert(permuted_experts_, - expanded_active_expert_rows, - num_experts, - total_rows_before_expert_, - stream); + sorter_.run(fc1_result_, + sorter_ws_size_bytes, + expert_for_source_row, + permuted_experts_, + source_rows_, + permuted_rows_, + k * num_rows, + false, + stream); - mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR; - mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; + initialize_moe_routing_kernelLauncher( + input_activations, + permuted_data_, + permuted_rows_, + expanded_source_row_to_expanded_dest_row, + num_rows, + num_rows, + hidden_size, + k, + stream); + + const int64_t expanded_active_expert_rows = k * num_rows; + + compute_total_rows_before_expert(permuted_experts_, + expanded_active_expert_rows, + num_experts, + total_rows_before_expert_, + stream); + + mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR; + mctlassExOrder_t column_major = + mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; + + mc_grouped_gemm_basic_kernel( + reinterpret_cast(permuted_data_), + row_major, + reinterpret_cast(ffn1_weight->data()), + column_major, + reinterpret_cast(ffn1_scale->data()), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), + row_major, + total_rows_before_expert_, + num_experts, + expanded_active_expert_rows, + inter_size, + hidden_size, + stream); + + if (moe_type == "ffn") { + auto act_out_tensor = + paddle::experimental::swiglu(fc1_out_tensor, nullptr); + auto act_out = act_out_tensor.data(); + + paddle::Tensor fc2_output_tensor = + GetEmptyTensor({k * num_rows, hidden_size}, input_type, place); + T *fc2_result = fc2_output_tensor.data(); mc_grouped_gemm_basic_kernel( - reinterpret_cast(permuted_data_), + reinterpret_cast(act_out), row_major, - reinterpret_cast(ffn1_weight->data()), + reinterpret_cast(ffn2_weight->data()), column_major, - reinterpret_cast(ffn1_scale->data()), - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), + reinterpret_cast(ffn2_scale->data()), + nullptr, + reinterpret_cast(fc2_result), row_major, total_rows_before_expert_, num_experts, expanded_active_expert_rows, - inter_size, hidden_size, + inter_size / 2, stream); - if (moe_type == "ffn") { - auto act_out_tensor = - paddle::experimental::swiglu(fc1_out_tensor, nullptr); - auto act_out = act_out_tensor.data(); - - paddle::Tensor fc2_output_tensor = - GetEmptyTensor({k * num_rows, hidden_size}, input_type, place); - T *fc2_result = fc2_output_tensor.data(); - - mc_grouped_gemm_basic_kernel( - reinterpret_cast(act_out), - row_major, - reinterpret_cast(ffn2_weight->data()), - column_major, - reinterpret_cast(ffn2_scale->data()), - nullptr, - reinterpret_cast(fc2_result), - row_major, - total_rows_before_expert_, - num_experts, - expanded_active_expert_rows, - hidden_size, - inter_size / 2, - stream); - - finalize_moe_routing_kernelLauncher( - fc2_result, - output_, - fc2_expert_biases, - reinterpret_cast(expert_scales_float), - expanded_source_row_to_expanded_dest_row, - expert_for_source_row, - num_rows, - hidden_size, - k, - static_cast(1), - norm_topk_prob, - routed_scaling_factor, - stream); - } else { - finalize_moe_routing_kernelLauncher( - // fc2_result, - fc1_out, - output_, - fc1_expert_biases, // fc2_expert_biases, - reinterpret_cast(expert_scales_float), - expanded_source_row_to_expanded_dest_row, - expert_for_source_row, - num_rows, - inter_size, - k, - static_cast(0), - norm_topk_prob, - routed_scaling_factor, - stream); - } + finalize_moe_routing_kernelLauncher( + fc2_result, + output_, + fc2_expert_biases, + reinterpret_cast(expert_scales_float), + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + hidden_size, + k, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); + } else { + finalize_moe_routing_kernelLauncher( + // fc2_result, + fc1_out, + output_, + fc1_expert_biases, // fc2_expert_biases, + reinterpret_cast(expert_scales_float), + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + inter_size, + k, + static_cast(0), + norm_topk_prob, + routed_scaling_factor, + stream); } + } -private: + private: std::string gemm_method_; CubKeyValueSorter sorter_; }; diff --git a/custom_ops/metax_ops/moe_dispatch.cu b/custom_ops/metax_ops/moe_dispatch.cu index e855666e0..e62d1746e 100644 --- a/custom_ops/metax_ops/moe_dispatch.cu +++ b/custom_ops/metax_ops/moe_dispatch.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wunused-function" @@ -24,7 +23,6 @@ #include "helper.h" - template void MoeDispatchKernel(const paddle::Tensor& input, const paddle::Tensor& gating_output, @@ -128,7 +126,6 @@ void MoeDispatchKernel(const paddle::Tensor& input, false, stream); - initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), @@ -140,16 +137,13 @@ void MoeDispatchKernel(const paddle::Tensor& input, moe_topk, stream); - - compute_total_rows_before_expert( - permuted_experts_, - moe_topk * num_rows, - expert_num, - tokens_expert_prefix_sum->data(), - stream); + compute_total_rows_before_expert(permuted_experts_, + moe_topk * num_rows, + expert_num, + tokens_expert_prefix_sum->data(), + stream); } - std::vector MoeExpertDispatch( const paddle::Tensor& input, const paddle::Tensor& gating_output, @@ -184,7 +178,6 @@ std::vector MoeExpertDispatch( auto permute_indices_per_token = GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place); - switch (input_type) { case paddle::DataType::BFLOAT16: MoeDispatchKernel(input, @@ -226,7 +219,6 @@ std::vector MoeExpertDispatch( top_k_indices}; } - std::vector> MoeExpertDispatchInferShape( const std::vector& input_shape, const std::vector& gating_output_shape, @@ -260,7 +252,6 @@ std::vector MoeExpertDispatchInferDtype( paddle::DataType::INT32}; } - PD_BUILD_OP(moe_expert_dispatch) .Inputs({"input", "gating_output"}) .Outputs({"permute_input", diff --git a/custom_ops/metax_ops/moe_ffn.cu b/custom_ops/metax_ops/moe_ffn.cu index daf9aec0d..bc268f769 100644 --- a/custom_ops/metax_ops/moe_ffn.cu +++ b/custom_ops/metax_ops/moe_ffn.cu @@ -14,19 +14,22 @@ // BUILD_MARK #pragma once -#include "mc_fused_moe_helper.h" #include "helper.h" +#include "mc_fused_moe_helper.h" -template +template void McMoeFFNKernel(const paddle::Tensor& permute_input, - const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_scale, - const std::string& quant_method, - paddle::Tensor ffn_out) { + const paddle::Tensor& tokens_expert_prefix_sum, + const paddle::Tensor& ffn1_weight, + const paddle::Tensor& ffn2_weight, + const paddle::optional& ffn1_bias, + const paddle::optional& ffn1_scale, + const paddle::optional& ffn2_scale, + const std::string& quant_method, + paddle::Tensor ffn_out) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -37,61 +40,65 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input, auto input_type = permute_input.dtype(); auto stream = permute_input.stream(); - const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k - const int num_experts = ffn1_weight.dims()[0]; // batchsize - const int hidden_size = ffn1_weight.dims()[2]; // n - int inter_dim = ffn1_weight.dims()[1]; // k + const int expanded_active_expert_rows = + permute_input.dims()[0]; // permute_input.dims(): m, k + const int num_experts = ffn1_weight.dims()[0]; // batchsize + const int hidden_size = ffn1_weight.dims()[2]; // n + int inter_dim = ffn1_weight.dims()[1]; // k - const int64_t inter_size = inter_dim; // since weight_only_int_8 + const int64_t inter_size = inter_dim; // since weight_only_int_8 paddle::Tensor fc1_out_tensor = GetEmptyTensor( {expanded_active_expert_rows, inter_size}, input_type, place); auto fc1_out_ptr = fc1_out_tensor.data(); mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR; - mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; + mctlassExOrder_t column_major = + mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; // ffn1 auto fc1_expert_biases = - ffn1_bias - ? const_cast(ffn1_bias.get_ptr())->data() - : nullptr; - auto fc1_expert_scales = const_cast(ffn1_scale.get_ptr())->data(); + ffn1_bias + ? const_cast(ffn1_bias.get_ptr())->data() + : nullptr; + auto fc1_expert_scales = + const_cast(ffn1_scale.get_ptr())->data(); mc_grouped_gemm_basic_kernel( - reinterpret_cast(permuted_input_ptr), - row_major, - reinterpret_cast(ffn1_weight.data()), - column_major, - reinterpret_cast(fc1_expert_scales), - reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out_ptr), - row_major, - tokens_expert_prefix_sum.data(), - num_experts, - expanded_active_expert_rows, - inter_dim, - hidden_size, - stream); + reinterpret_cast(permuted_input_ptr), + row_major, + reinterpret_cast(ffn1_weight.data()), + column_major, + reinterpret_cast(fc1_expert_scales), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out_ptr), + row_major, + tokens_expert_prefix_sum.data(), + num_experts, + expanded_active_expert_rows, + inter_dim, + hidden_size, + stream); // swiglu auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr); auto act_out = act_out_tensor.data(); - auto fc2_expert_scales = const_cast(ffn2_scale.get_ptr())->data(); + auto fc2_expert_scales = + const_cast(ffn2_scale.get_ptr())->data(); mc_grouped_gemm_basic_kernel( - reinterpret_cast(act_out), - row_major, - reinterpret_cast(ffn2_weight.data()), - column_major, - reinterpret_cast(fc2_expert_scales), - nullptr, - reinterpret_cast(ffn_out_ptr), - row_major, - tokens_expert_prefix_sum.data(), - num_experts, - expanded_active_expert_rows, - hidden_size, - inter_dim / 2, - stream); + reinterpret_cast(act_out), + row_major, + reinterpret_cast(ffn2_weight.data()), + column_major, + reinterpret_cast(fc2_expert_scales), + nullptr, + reinterpret_cast(ffn_out_ptr), + row_major, + tokens_expert_prefix_sum.data(), + num_experts, + expanded_active_expert_rows, + hidden_size, + inter_dim / 2, + stream); } std::vector MoeExpertFFN( @@ -109,15 +116,18 @@ std::vector MoeExpertFFN( switch (input_type) { case paddle::DataType::BFLOAT16: - McMoeFFNKernel(permute_input, - tokens_expert_prefix_sum, - ffn1_weight, - ffn2_weight, - ffn1_bias, - ffn1_scale, - ffn2_scale, - quant_method, - ffn_out); + McMoeFFNKernel(permute_input, + tokens_expert_prefix_sum, + ffn1_weight, + ffn2_weight, + ffn1_bias, + ffn1_scale, + ffn2_scale, + quant_method, + ffn_out); break; // case paddle::DataType::FLOAT16: // MoeFFNKernel(permute_input, diff --git a/custom_ops/metax_ops/moe_reduce.cu b/custom_ops/metax_ops/moe_reduce.cu index be9e84ce7..7ec694215 100644 --- a/custom_ops/metax_ops/moe_reduce.cu +++ b/custom_ops/metax_ops/moe_reduce.cu @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. - #pragma once -#include "helper.h" #include "fused_moe_helper.h" #include "fused_moe_op.h" +#include "helper.h" template void MoeReduceKernel(const paddle::Tensor& ffn_out, @@ -52,7 +51,6 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out, stream); } - std::vector MoeExpertReduce( const paddle::Tensor& ffn_out, const paddle::Tensor& top_k_weight, @@ -106,7 +104,6 @@ std::vector MoeExpertReduce( return {output}; } - std::vector> MoeExpertReduceInferShape( const std::vector& ffn_out_shape, const std::vector& top_k_weight_shape, @@ -129,7 +126,6 @@ std::vector MoeExpertReduceInferDtype( return {ffn_out_dtype}; } - PD_BUILD_OP(moe_expert_reduce) .Inputs({"ffn_out", "top_k_weight", diff --git a/custom_ops/xpu_ops/src/ops/adjust_batch.cc b/custom_ops/xpu_ops/src/ops/adjust_batch.cc index d087a0910..d263d2cae 100644 --- a/custom_ops/xpu_ops/src/ops/adjust_batch.cc +++ b/custom_ops/xpu_ops/src/ops/adjust_batch.cc @@ -12,64 +12,69 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "paddle/extension.h" #include "paddle/phi/core/enforce.h" #include "utility/helper.h" #include "xpu/plugin.h" -#include template -std::vector -AdjustBatchKernel(const paddle::Tensor &x, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &encoder_seq_lod, - const paddle::Tensor &encoder_batch_idx, - const paddle::Tensor &decoder_batch_idx, - const paddle::Tensor &encoder_seq_lod_cpu, - const paddle::Tensor &encoder_batch_idx_cpu, - const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, - const paddle::optional &output_padding_offset, - int max_input_length) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - PD_CHECK(x.dtype() == T); - PD_CHECK(x.dims().size() == 2); +std::vector AdjustBatchKernel( + const paddle::Tensor &x, // [token_num, dim_embed] + const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &encoder_batch_idx, + const paddle::Tensor &decoder_batch_idx, + const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &encoder_batch_idx_cpu, + const paddle::Tensor &decoder_batch_idx_cpu, + const paddle::Tensor &enc_batch_tensor, + const paddle::Tensor &dec_batch_tensor, + const paddle::optional &output_padding_offset, + int max_input_length) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + PD_CHECK(x.dtype() == T); + PD_CHECK(x.dims().size() == 2); - using XPUType = typename XPUTypeTrait::DataType>::Type; - using data_t = typename PDTraits::data_t; - const int token_num = x.dims()[0]; - const int dim = x.dims()[1]; - const int bsz = cum_offsets.shape()[0]; - int enc_batch = enc_batch_tensor.data()[0]; - int dec_batch = dec_batch_tensor.data()[0]; + using XPUType = typename XPUTypeTrait::DataType>::Type; + using data_t = typename PDTraits::data_t; + const int token_num = x.dims()[0]; + const int dim = x.dims()[1]; + const int bsz = cum_offsets.shape()[0]; + int enc_batch = enc_batch_tensor.data()[0]; + int dec_batch = dec_batch_tensor.data()[0]; - baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ - const_cast(encoder_seq_lod_cpu.data()), - enc_batch + 1, const_cast(encoder_seq_lod.data())}; - baidu::xpu::api::VectorParam encoder_batch_map_vp{ - const_cast(encoder_batch_idx_cpu.data()), enc_batch, - const_cast(encoder_batch_idx.data())}; - baidu::xpu::api::VectorParam decoder_batch_map_vp{ - const_cast(decoder_batch_idx_cpu.data()), dec_batch, - const_cast(decoder_batch_idx.data())}; + baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ + const_cast(encoder_seq_lod_cpu.data()), + enc_batch + 1, + const_cast(encoder_seq_lod.data())}; + baidu::xpu::api::VectorParam encoder_batch_map_vp{ + const_cast(encoder_batch_idx_cpu.data()), + enc_batch, + const_cast(encoder_batch_idx.data())}; + baidu::xpu::api::VectorParam decoder_batch_map_vp{ + const_cast(decoder_batch_idx_cpu.data()), + dec_batch, + const_cast(decoder_batch_idx.data())}; - auto out = paddle::full({token_num, dim}, -2, x.type(), x.place()); + auto out = paddle::full({token_num, dim}, -2, x.type(), x.place()); - int r = baidu::xpu::api::plugin::eb_adjust_batch( - xpu_ctx->x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(out.data()), encoder_seqs_lods_vp, - encoder_batch_map_vp, decoder_batch_map_vp, dim); - return {out}; + int r = baidu::xpu::api::plugin::eb_adjust_batch( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + return {out}; } using AdjustBatchKernelFuncPtr = std::vector (*)( - const paddle::Tensor &x, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &x, // [token_num, dim_embed] + const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &encoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, @@ -81,42 +86,50 @@ using AdjustBatchKernelFuncPtr = std::vector (*)( const paddle::optional &output_padding_offset, int max_input_length); -std::vector -AdjustBatch(const paddle::Tensor &x, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &encoder_seq_lod, - const paddle::Tensor &encoder_batch_idx, - const paddle::Tensor &decoder_batch_idx, - const paddle::Tensor &encoder_seq_lod_cpu, - const paddle::Tensor &encoder_batch_idx_cpu, - const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, - const paddle::optional &output_padding_offset, - int max_input_length) { - AdjustBatchKernelFuncPtr func = nullptr; +std::vector AdjustBatch( + const paddle::Tensor &x, // [token_num, dim_embed] + const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &encoder_batch_idx, + const paddle::Tensor &decoder_batch_idx, + const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &encoder_batch_idx_cpu, + const paddle::Tensor &decoder_batch_idx_cpu, + const paddle::Tensor &enc_batch_tensor, + const paddle::Tensor &dec_batch_tensor, + const paddle::optional &output_padding_offset, + int max_input_length) { + AdjustBatchKernelFuncPtr func = nullptr; - switch (x.dtype()) { + switch (x.dtype()) { case paddle::DataType::BFLOAT16: - func = &AdjustBatchKernel; - break; + func = &AdjustBatchKernel; + break; case paddle::DataType::FLOAT16: - func = &AdjustBatchKernel; - break; + func = &AdjustBatchKernel; + break; case paddle::DataType::FLOAT32: - func = &AdjustBatchKernel; - break; + func = &AdjustBatchKernel; + break; case paddle::DataType::INT64: - func = &AdjustBatchKernel; - break; + func = &AdjustBatchKernel; + break; default: - PD_THROW("Unsupported data type: ", x.dtype()); - } + PD_THROW("Unsupported data type: ", x.dtype()); + } - return func(x, cum_offsets, encoder_seq_lod, encoder_batch_idx, - decoder_batch_idx, encoder_seq_lod_cpu, encoder_batch_idx_cpu, - decoder_batch_idx_cpu, enc_batch_tensor, dec_batch_tensor, - output_padding_offset, max_input_length); + return func(x, + cum_offsets, + encoder_seq_lod, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + enc_batch_tensor, + dec_batch_tensor, + output_padding_offset, + max_input_length); } std::vector> AdjustBatchInferShape( @@ -131,16 +144,17 @@ std::vector> AdjustBatchInferShape( const std::vector &enc_batch_tensor_shape, const std::vector &dec_batch_tensor_shape, const paddle::optional> &output_padding_offset_shape) { - if (output_padding_offset_shape) { - PD_THROW("speculative decoding is not supported in XPU."); - } - int64_t token_num = x_shape[0]; - int64_t dim_embed = x_shape[1]; - return {{token_num, dim_embed}}; + if (output_padding_offset_shape) { + PD_THROW("speculative decoding is not supported in XPU."); + } + int64_t token_num = x_shape[0]; + int64_t dim_embed = x_shape[1]; + return {{token_num, dim_embed}}; } std::vector AdjustBatchInferDtype( - const paddle::DataType &x_dtype, const paddle::DataType &cum_offsets_dtype, + const paddle::DataType &x_dtype, + const paddle::DataType &cum_offsets_dtype, const paddle::DataType &encoder_seq_lod_dtype, const paddle::DataType &encoder_batch_idx_dtype, const paddle::DataType &decoder_batch_idx_dtype, @@ -150,14 +164,20 @@ std::vector AdjustBatchInferDtype( const paddle::DataType &enc_batch_tensor_dtype, const paddle::DataType &dec_batch_tensor_dtype, const paddle::optional &output_padding_offset_dtype) { - return {x_dtype}; + return {x_dtype}; } PD_BUILD_OP(adjust_batch) - .Inputs({"x", "cum_offsets", "encoder_seq_lod", "encoder_batch_idx", - "decoder_batch_idx", "encoder_seq_lod_cpu", - "encoder_batch_idx_cpu", "decoder_batch_idx_cpu", - "enc_batch_tensor", "dec_batch_tensor", + .Inputs({"x", + "cum_offsets", + "encoder_seq_lod", + "encoder_batch_idx", + "decoder_batch_idx", + "encoder_seq_lod_cpu", + "encoder_batch_idx_cpu", + "decoder_batch_idx_cpu", + "enc_batch_tensor", + "dec_batch_tensor", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index 72ae24749..f6ade82b1 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -89,7 +89,7 @@ std::vector BlockAttnKernel( const paddle::optional& smooth, const paddle::optional& kv_signal_data_cpu, const paddle::optional& cachekv_signal_thread_cpu, - const std::string &pos_emb_type, + const std::string& pos_emb_type, bool rope_3d) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); @@ -215,8 +215,8 @@ std::vector BlockAttnKernel( param.prefill_len = is_prefix_cache ? param.max_valid_seqlen : -1; param.page_attn.block_size = block_size; param.page_attn.max_num_blocks_per_seq = prefix_block_num_per_seq; - // prefix_block_tables is a subset of block_tables, which is used for prefix - // cache + // prefix_block_tables is a subset of block_tables, which is used for + // prefix cache xftblock::Tensor prefix_block_tables_tensor( is_prefix_cache ? reinterpret_cast(const_cast( prefix_block_tables.data())) @@ -306,12 +306,12 @@ std::vector BlockAttnKernel( reinterpret_cast(key_cache.data())), const_cast(reinterpret_cast( value_cache.data())), - vsl.usual_lod_vp, // seq_lod - vsl.slot_mapping_vp, // real_batch - prefix_lens_vp, // start_tokens - param.batch_size, // batch_size - 1, // emb_batch_size - rope_max_seqlen, // max_seqlen + vsl.usual_lod_vp, // seq_lod + vsl.slot_mapping_vp, // real_batch + prefix_lens_vp, // start_tokens + param.batch_size, // batch_size + 1, // emb_batch_size + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, @@ -480,7 +480,8 @@ std::vector BlockAttnKernel( api::VectorParam decoder_context_len_vp = { const_cast(decoder_context_len_cpu.data()), dec_batch, - nullptr}; // use for speculative_attention_decoder seq_len in MTP + nullptr}; // use for speculative_attention_decoder seq_len in + // MTP api::VectorParam decoder_context_len_cache_vp = { const_cast(decoder_context_len_cache_cpu.data()), dec_batch, @@ -597,49 +598,49 @@ std::vector BlockAttnKernel( tfloat32, int8_wo_t>; constexpr int quant_mode = std::is_same_v ? 3 : 0; - ret = baidu::xpu::xfa::speculative_attention_decoder( - xpu_ctx->x_context(), - decode_output_ptr, // out - q_buf_ptr, // q - nullptr, // k - nullptr, // v - reinterpret_cast( - key_cache.data()), // k_cache - reinterpret_cast( - value_cache.data()), // v_cache - reinterpret_cast( - block_tables.data()), // block_tables - decoder_context_len_vp, // seq_lengths - decoder_batch_map_vp, // valid_batch - param.max_batch_size, // batch_num - q_len, // qlen - max_seq_len, // max_seq_len - param.head_num, // head_num - param.head_dim, // head_dim - param.kv_head_num, // kv_head_num - nullptr, // attn_mask - 1.0f / - std::sqrt(static_cast(param.head_dim)), // scale 【check】 - block_size, // block_size - max_block_per_seq, // max_blocks_per_seq - -1, // max_window_size - nullptr, // q_maxptr - has_zp // k_cache_maxptr - ? fake_perhead_scale - : quant_k_scale_inv, - has_zp // v_cache_maxptr - ? fake_perhead_scale - : quant_v_scale_inv, - nullptr, // o_maxptr - param.head_dim); // vo_head_dim - PD_CHECK(0, "speculative_attention unimplemented"); + ret = baidu::xpu::xfa::speculative_attention_decoder( + xpu_ctx->x_context(), + decode_output_ptr, // out + q_buf_ptr, // q + nullptr, // k + nullptr, // v + reinterpret_cast( + key_cache.data()), // k_cache + reinterpret_cast( + value_cache.data()), // v_cache + reinterpret_cast( + block_tables.data()), // block_tables + decoder_context_len_vp, // seq_lengths + decoder_batch_map_vp, // valid_batch + param.max_batch_size, // batch_num + q_len, // qlen + max_seq_len, // max_seq_len + param.head_num, // head_num + param.head_dim, // head_dim + param.kv_head_num, // kv_head_num + nullptr, // attn_mask + 1.0f / + std::sqrt(static_cast(param.head_dim)), // scale 【check】 + block_size, // block_size + max_block_per_seq, // max_blocks_per_seq + -1, // max_window_size + nullptr, // q_maxptr + has_zp // k_cache_maxptr + ? fake_perhead_scale + : quant_k_scale_inv, + has_zp // v_cache_maxptr + ? fake_perhead_scale + : quant_v_scale_inv, + nullptr, // o_maxptr + param.head_dim); // vo_head_dim + PD_CHECK(0, "speculative_attention unimplemented"); PD_CHECK(ret == api::SUCCESS, "xfa::speculative_attention_decoder failed."); if (!Eq_len) { @@ -702,11 +703,11 @@ std::vector BlockAttnKernel( reinterpret_cast(key_cache.data())), const_cast( reinterpret_cast(value_cache.data())), - vsl.usual_lod_vp, // seq_lod - vsl.slot_mapping_vp, // real_batch - param.batch_size, // batch_size - 1, // emb_batch_size = rotary_embs.dims()[1] = 1 - rope_max_seqlen, // max_seqlen + vsl.usual_lod_vp, // seq_lod + vsl.slot_mapping_vp, // real_batch + param.batch_size, // batch_size + 1, // emb_batch_size = rotary_embs.dims()[1] = 1 + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, @@ -777,7 +778,8 @@ std::vector BlockAttnKernel( ret = xftblock::xft_decoder_core_attenion_block< XPU_XType, XPU_CType, - XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high precision + XPU_XType>( // TGEMM = XPU_XType TODOlizan03: used high + // precision &xctx, &q_buf, &key_cache_tensor, @@ -867,8 +869,8 @@ std::vector BlockAttn( const paddle::optional& smooth, const paddle::optional& kv_signal_data_cpu, const paddle::optional& cachekv_signal_thread_cpu, - const std::string &pos_emb_type="NORMAL", - bool rope_3d=false) { + const std::string& pos_emb_type = "NORMAL", + bool rope_3d = false) { #define APPLY_KERNEL(TX, TC, TS) \ return BlockAttnKernel(qkv, \ key_cache, \ diff --git a/custom_ops/xpu_ops/src/ops/device/get_context_gm_max_mem_demand.cc b/custom_ops/xpu_ops/src/ops/device/get_context_gm_max_mem_demand.cc index 3677106ae..461a134df 100644 --- a/custom_ops/xpu_ops/src/ops/device/get_context_gm_max_mem_demand.cc +++ b/custom_ops/xpu_ops/src/ops/device/get_context_gm_max_mem_demand.cc @@ -12,13 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" -#include "xpu/plugin.h" -#include "xpu/xpuml.h" -#include #include #include -#include #include #include #include @@ -26,27 +21,31 @@ #include #include #include +#include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" +#include "xpu/xpuml.h" std::vector GetMaxMemDemand(int64_t device_id) { - if (device_id == -1) { - device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); - } - phi::XPUPlace place(device_id); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + if (device_id == -1) { + device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); + } + phi::XPUPlace place(device_id); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - paddle::Tensor max_mem_demand = paddle::zeros({1}, paddle::DataType::INT64); + paddle::Tensor max_mem_demand = paddle::zeros({1}, paddle::DataType::INT64); - max_mem_demand.data()[0] = - xpu_ctx->x_context()->_gm_mgr.get_max_mem_demand(); - return {max_mem_demand}; + max_mem_demand.data()[0] = + xpu_ctx->x_context()->_gm_mgr.get_max_mem_demand(); + return {max_mem_demand}; } std::vector> GetMaxMemDemandInferShape() { return {{1}}; } std::vector GetMaxMemDemandInferDtype() { - return {paddle::DataType::INT64}; + return {paddle::DataType::INT64}; } PD_BUILD_OP(xpu_get_context_gm_max_mem_demand) diff --git a/custom_ops/xpu_ops/src/ops/device/get_free_global_memory.cc b/custom_ops/xpu_ops/src/ops/device/get_free_global_memory.cc index 228951bee..04de30e8a 100644 --- a/custom_ops/xpu_ops/src/ops/device/get_free_global_memory.cc +++ b/custom_ops/xpu_ops/src/ops/device/get_free_global_memory.cc @@ -12,13 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" -#include "xpu/plugin.h" -#include "xpu/xpuml.h" -#include #include #include -#include #include #include #include @@ -26,30 +21,35 @@ #include #include #include +#include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" +#include "xpu/xpuml.h" std::vector GetFreeGlobalMemory(int64_t device_id) { - if (device_id == -1) { - device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); - } + if (device_id == -1) { + device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); + } - paddle::Tensor free_global_memory = - paddle::zeros({1}, paddle::DataType::INT64); + paddle::Tensor free_global_memory = + paddle::zeros({1}, paddle::DataType::INT64); - xpumlDevice_t device_handle; - xpumlInit(); - xpumlDeviceGetHandleByIndex(device_id, &device_handle); - xpumlMemory_t device_memory; - xpumlDeviceGetMemoryInfo(device_handle, &device_memory); - free_global_memory.data()[0] = device_memory.freeGlobalMemory; - return {free_global_memory}; + xpumlDevice_t device_handle; + xpumlInit(); + xpumlDeviceGetHandleByIndex(device_id, &device_handle); + xpumlMemory_t device_memory; + xpumlDeviceGetMemoryInfo(device_handle, &device_memory); + free_global_memory.data()[0] = device_memory.freeGlobalMemory; + return {free_global_memory}; } std::vector> GetFreeGlobalMemoryInferShape() { - return {{1}}; + return {{1}}; } std::vector GetFreeGlobalMemoryInferDtype() { - return {paddle::DataType::INT64}; + return {paddle::DataType::INT64}; } PD_BUILD_OP(xpu_get_free_global_memory) diff --git a/custom_ops/xpu_ops/src/ops/device/get_total_global_memory.cc b/custom_ops/xpu_ops/src/ops/device/get_total_global_memory.cc index feba89449..dea1f05a9 100644 --- a/custom_ops/xpu_ops/src/ops/device/get_total_global_memory.cc +++ b/custom_ops/xpu_ops/src/ops/device/get_total_global_memory.cc @@ -12,13 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" -#include "xpu/plugin.h" -#include "xpu/xpuml.h" -#include #include #include -#include #include #include #include @@ -26,29 +21,34 @@ #include #include #include +#include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" +#include "xpu/xpuml.h" std::vector GetTotalGlobalMemory(int64_t device_id) { - if (device_id == -1) { - device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); - } + if (device_id == -1) { + device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); + } - paddle::Tensor total_global_memory = - paddle::zeros({1}, paddle::DataType::INT64); - xpumlDevice_t device_handle; - xpumlInit(); - xpumlDeviceGetHandleByIndex(device_id, &device_handle); - xpumlMemory_t device_memory; - xpumlDeviceGetMemoryInfo(device_handle, &device_memory); - total_global_memory.data()[0] = device_memory.totalGlobalMemory; - return {total_global_memory}; + paddle::Tensor total_global_memory = + paddle::zeros({1}, paddle::DataType::INT64); + xpumlDevice_t device_handle; + xpumlInit(); + xpumlDeviceGetHandleByIndex(device_id, &device_handle); + xpumlMemory_t device_memory; + xpumlDeviceGetMemoryInfo(device_handle, &device_memory); + total_global_memory.data()[0] = device_memory.totalGlobalMemory; + return {total_global_memory}; } std::vector> GetTotalGlobalMemoryInferShape() { - return {{1}}; + return {{1}}; } std::vector GetTotalGlobalMemoryInferDtype() { - return {paddle::DataType::INT64}; + return {paddle::DataType::INT64}; } PD_BUILD_OP(xpu_get_total_global_memory) diff --git a/custom_ops/xpu_ops/src/ops/device/get_used_global_memory.cc b/custom_ops/xpu_ops/src/ops/device/get_used_global_memory.cc index a6d05788e..badf01e95 100644 --- a/custom_ops/xpu_ops/src/ops/device/get_used_global_memory.cc +++ b/custom_ops/xpu_ops/src/ops/device/get_used_global_memory.cc @@ -12,13 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" -#include "xpu/plugin.h" -#include "xpu/xpuml.h" -#include #include #include -#include #include #include #include @@ -26,29 +21,34 @@ #include #include #include +#include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" +#include "xpu/xpuml.h" std::vector GetUsedGlobalMemory(int64_t device_id) { - if (device_id == -1) { - device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); - } + if (device_id == -1) { + device_id = phi::backends::xpu::GetXPUCurrentDeviceId(); + } - paddle::Tensor used_global_memory = - paddle::zeros({1}, paddle::DataType::INT64); - xpumlDevice_t device_handle; - xpumlInit(); - xpumlDeviceGetHandleByIndex(device_id, &device_handle); - xpumlMemory_t device_memory; - xpumlDeviceGetMemoryInfo(device_handle, &device_memory); - used_global_memory.data()[0] = device_memory.usedGlobalMemory; - return {used_global_memory}; + paddle::Tensor used_global_memory = + paddle::zeros({1}, paddle::DataType::INT64); + xpumlDevice_t device_handle; + xpumlInit(); + xpumlDeviceGetHandleByIndex(device_id, &device_handle); + xpumlMemory_t device_memory; + xpumlDeviceGetMemoryInfo(device_handle, &device_memory); + used_global_memory.data()[0] = device_memory.usedGlobalMemory; + return {used_global_memory}; } std::vector> GetUsedGlobalMemoryInferShape() { - return {{1}}; + return {{1}}; } std::vector GetUsedGlobalMemoryInferDtype() { - return {paddle::DataType::INT64}; + return {paddle::DataType::INT64}; } PD_BUILD_OP(xpu_get_used_global_memory) diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index 8d9aedcee..bc875b372 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -12,52 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "paddle/extension.h" #include "xpu/plugin.h" -#include -std::vector -GatherNextToken(const paddle::Tensor &tmp_out, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &encoder_seq_lod, - const paddle::Tensor &encoder_batch_map, - const paddle::Tensor &decoder_batch_map, - const paddle::Tensor &encoder_seq_lod_cpu, - const paddle::Tensor &encoder_batch_map_cpu, - const paddle::Tensor &decoder_batch_map_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, - const paddle::optional &output_padding_offset, - int max_input_length) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - using XPUType = - typename XPUTypeTrait::Type; // only support bfloat16 - typedef paddle::bfloat16 data_t; - const int dim = tmp_out.dims()[1]; - const int bsz = cum_offsets.shape()[0]; - int enc_batch = enc_batch_tensor.data()[0]; - int dec_batch = dec_batch_tensor.data()[0]; +std::vector GatherNextToken( + const paddle::Tensor &tmp_out, // [token_num, dim_embed] + const paddle::Tensor &cum_offsets, // [bsz, 1] + const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &encoder_batch_map, + const paddle::Tensor &decoder_batch_map, + const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &encoder_batch_map_cpu, + const paddle::Tensor &decoder_batch_map_cpu, + const paddle::Tensor &enc_batch_tensor, + const paddle::Tensor &dec_batch_tensor, + const paddle::optional &output_padding_offset, + int max_input_length) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + using XPUType = + typename XPUTypeTrait::Type; // only support bfloat16 + typedef paddle::bfloat16 data_t; + const int dim = tmp_out.dims()[1]; + const int bsz = cum_offsets.shape()[0]; + int enc_batch = enc_batch_tensor.data()[0]; + int dec_batch = dec_batch_tensor.data()[0]; - baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ - const_cast(encoder_seq_lod_cpu.data()), - enc_batch + 1, const_cast(encoder_seq_lod.data())}; - baidu::xpu::api::VectorParam encoder_batch_map_vp{ - const_cast(encoder_batch_map_cpu.data()), enc_batch, - const_cast(encoder_batch_map.data())}; - baidu::xpu::api::VectorParam decoder_batch_map_vp{ - const_cast(decoder_batch_map_cpu.data()), dec_batch, - const_cast(decoder_batch_map.data())}; + baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ + const_cast(encoder_seq_lod_cpu.data()), + enc_batch + 1, + const_cast(encoder_seq_lod.data())}; + baidu::xpu::api::VectorParam encoder_batch_map_vp{ + const_cast(encoder_batch_map_cpu.data()), + enc_batch, + const_cast(encoder_batch_map.data())}; + baidu::xpu::api::VectorParam decoder_batch_map_vp{ + const_cast(decoder_batch_map_cpu.data()), + dec_batch, + const_cast(decoder_batch_map.data())}; - auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place()); + auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place()); - int r = baidu::xpu::api::plugin::eb_gather_next_token( - xpu_ctx->x_context(), - reinterpret_cast(tmp_out.data()), - reinterpret_cast(out.data()), encoder_seqs_lods_vp, - encoder_batch_map_vp, decoder_batch_map_vp, dim); - return {out}; + int r = baidu::xpu::api::plugin::eb_gather_next_token( + xpu_ctx->x_context(), + reinterpret_cast(tmp_out.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + return {out}; } std::vector> GatherNextTokenInferShape( @@ -72,12 +77,12 @@ std::vector> GatherNextTokenInferShape( const std::vector &enc_batch_tensor_shape, const std::vector &dec_batch_tensor_shape, const paddle::optional> &output_padding_offset_shape) { - if (output_padding_offset_shape) { - PD_THROW("speculative decoding is not supported in XPU."); - } - int64_t bsz = cum_offsets_shape[0]; - int64_t dim_embed = tmp_out_shape[1]; - return {{bsz, dim_embed}}; + if (output_padding_offset_shape) { + PD_THROW("speculative decoding is not supported in XPU."); + } + int64_t bsz = cum_offsets_shape[0]; + int64_t dim_embed = tmp_out_shape[1]; + return {{bsz, dim_embed}}; } std::vector GatherNextTokenInferDtype( @@ -92,14 +97,20 @@ std::vector GatherNextTokenInferDtype( const paddle::DataType &enc_batch_tensor_dtype, const paddle::DataType &dec_batch_tensor_dtype, const paddle::optional &output_padding_offset_dtype) { - return {tmp_out_dtype}; + return {tmp_out_dtype}; } PD_BUILD_OP(gather_next_token) - .Inputs({"tmp_out", "cum_offsets", "encoder_seq_lod", "encoder_batch_map", - "decoder_batch_map", "encoder_seq_lod_cpu", - "encoder_batch_map_cpu", "decoder_batch_map_cpu", - "enc_batch_tensor", "dec_batch_tensor", + .Inputs({"tmp_out", + "cum_offsets", + "encoder_seq_lod", + "encoder_batch_map", + "decoder_batch_map", + "encoder_seq_lod_cpu", + "encoder_batch_map_cpu", + "decoder_batch_map_cpu", + "enc_batch_tensor", + "dec_batch_tensor", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) diff --git a/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc b/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc index 30ca6d269..ca21a1d16 100644 --- a/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc +++ b/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc @@ -14,43 +14,48 @@ #include "paddle/extension.h" -std::vector GetImgBoundaries(const paddle::Tensor& task_input_ids, - const paddle::Tensor& grid_thw, - const int64_t image_patch_id) { - // All tensor in cpu - auto input_ids_ptr = task_input_ids.data(); - int64_t seq_lens_origin = task_input_ids.numel(); - auto grid_thw_ptr = grid_thw.data(); +std::vector GetImgBoundaries( + const paddle::Tensor& task_input_ids, + const paddle::Tensor& grid_thw, + const int64_t image_patch_id) { + // All tensor in cpu + auto input_ids_ptr = task_input_ids.data(); + int64_t seq_lens_origin = task_input_ids.numel(); + auto grid_thw_ptr = grid_thw.data(); - int token_times = 4; - int token_idx = 0; - int image_idx = 0; - std::vector img_boundaries, img_nums; - img_boundaries.emplace_back(0); - img_nums.emplace_back(0); - while (token_idx < seq_lens_origin) { - if (input_ids_ptr[token_idx] != image_patch_id) { - do { - token_idx++; - } while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id); - } else { - int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times; - image_idx++; - token_idx += cur_image_token_len; - } - img_boundaries.emplace_back(token_idx); - img_nums.emplace_back(image_idx); + int token_times = 4; + int token_idx = 0; + int image_idx = 0; + std::vector img_boundaries, img_nums; + img_boundaries.emplace_back(0); + img_nums.emplace_back(0); + while (token_idx < seq_lens_origin) { + if (input_ids_ptr[token_idx] != image_patch_id) { + do { + token_idx++; + } while (token_idx < seq_lens_origin && + input_ids_ptr[token_idx] != image_patch_id); + } else { + int cur_image_token_len = + (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / + token_times; + image_idx++; + token_idx += cur_image_token_len; } + img_boundaries.emplace_back(token_idx); + img_nums.emplace_back(image_idx); + } - int64_t num_img_boundaries = static_cast(img_boundaries.size()); - auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace()); + int64_t num_img_boundaries = static_cast(img_boundaries.size()); + auto out = paddle::full( + {2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace()); - for (int i = 0; i < num_img_boundaries; i++) { - out.data()[i] = img_boundaries[i]; - out.data()[num_img_boundaries + i] = img_nums[i]; - } + for (int i = 0; i < num_img_boundaries; i++) { + out.data()[i] = img_boundaries[i]; + out.data()[num_img_boundaries + i] = img_nums[i]; + } - return {out}; + return {out}; } PD_BUILD_OP(get_img_boundaries) diff --git a/custom_ops/xpu_ops/src/ops/get_output.cc b/custom_ops/xpu_ops/src/ops/get_output.cc index 6886f441f..a1150e008 100644 --- a/custom_ops/xpu_ops/src/ops/get_output.cc +++ b/custom_ops/xpu_ops/src/ops/get_output.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" #include #include #include #include #include #include "msg_utils.h" +#include "paddle/extension.h" -void GetOutputKVSignal(const paddle::Tensor& x, +void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) { int msg_queue_id = 1024 + rank_id; @@ -28,7 +28,7 @@ void GetOutputKVSignal(const paddle::Tensor& x, static key_t key = ftok("/opt/", msg_queue_id); static int msgid = msgget(key, IPC_CREAT | 0666); - int* out_data = const_cast(x.data()); + int *out_data = const_cast(x.data()); int ret = -1; if (!wait_flag) { ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT); @@ -48,69 +48,72 @@ void GetOutputKVSignal(const paddle::Tensor& x, return; } -void GetOutput(const paddle::Tensor &x, int64_t rank_id, bool wait_flag, +void GetOutput(const paddle::Tensor &x, + int64_t rank_id, + bool wait_flag, int msg_queue_id) { - if (rank_id > 0) { - return; - } - static struct msgdata msg_rcv; - if (const char *inference_msg_queue_id_env_p = - std::getenv("INFERENCE_MSG_QUEUE_ID")) { - std::string inference_msg_queue_id_env_str( - inference_msg_queue_id_env_p); - int inference_msg_queue_id_from_env = - std::stoi(inference_msg_queue_id_env_str); -#ifdef GET_OUTPUT_DEBUG - std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " - << inference_msg_queue_id_from_env << std::endl; -#endif - msg_queue_id = inference_msg_queue_id_from_env; - } - static key_t key = ftok("/dev/shm", msg_queue_id); - static int msgid = msgget(key, IPC_CREAT | 0666); - -#ifdef GET_OUTPUT_DEBUG - std::cout << "get_output msg_queue_id: " << msg_queue_id << std::endl; - std::cout << "get_output key: " << key << std::endl; - std::cout << "get_output msgid: " << msgid << std::endl; - std::cout << "get_output wait_flag: " << wait_flag << std::endl; -#endif - - int64_t *out_data = const_cast(x.data()); - int ret = -1; - if (!wait_flag) { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); - } else { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); - } - -#ifdef GET_OUTPUT_DEBUG - std::cout << "get_output finish msgrcv" << std::endl; -#endif - if (ret == -1) { - out_data[0] = -2; - out_data[1] = 0; - return; - } - int bsz = msg_rcv.mtext[1]; - - for (int64_t i = 0; i < bsz + 2; i++) { - out_data[i] = (int64_t)msg_rcv.mtext[i]; - } -#ifdef GET_OUTPUT_DEBUG - std::cout << "get_output finished: " << msgid << std::endl; -#endif - + if (rank_id > 0) { return; + } + static struct msgdata msg_rcv; + if (const char *inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); +#ifdef GET_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + msg_queue_id = inference_msg_queue_id_from_env; + } + static key_t key = ftok("/dev/shm", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); + +#ifdef GET_OUTPUT_DEBUG + std::cout << "get_output msg_queue_id: " << msg_queue_id << std::endl; + std::cout << "get_output key: " << key << std::endl; + std::cout << "get_output msgid: " << msgid << std::endl; + std::cout << "get_output wait_flag: " << wait_flag << std::endl; +#endif + + int64_t *out_data = const_cast(x.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); + } else { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); + } + +#ifdef GET_OUTPUT_DEBUG + std::cout << "get_output finish msgrcv" << std::endl; +#endif + if (ret == -1) { + out_data[0] = -2; + out_data[1] = 0; + return; + } + int bsz = msg_rcv.mtext[1]; + + for (int64_t i = 0; i < bsz + 2; i++) { + out_data[i] = (int64_t)msg_rcv.mtext[i]; + } +#ifdef GET_OUTPUT_DEBUG + std::cout << "get_output finished: " << msgid << std::endl; +#endif + + return; } void GetOutputStatic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) { - GetOutput(x, rank_id, wait_flag, 1); + GetOutput(x, rank_id, wait_flag, 1); } -void GetOutputDynamic(const paddle::Tensor &x, int64_t rank_id, bool wait_flag, +void GetOutputDynamic(const paddle::Tensor &x, + int64_t rank_id, + bool wait_flag, int msg_queue_id) { - GetOutput(x, rank_id, wait_flag, msg_queue_id); + GetOutput(x, rank_id, wait_flag, msg_queue_id); } PD_BUILD_OP(get_output) diff --git a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc index e83cecb19..7c1824372 100644 --- a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc @@ -20,44 +20,43 @@ std::vector GetPaddingOffset(const paddle::Tensor &input_ids, const paddle::Tensor &cum_offsets, const paddle::Tensor &token_num, const paddle::Tensor &seq_len) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - std::vector input_ids_shape = input_ids.shape(); - const int bsz = seq_len.shape()[0]; - const int seq_length = input_ids_shape[1]; - auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); - auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int seq_length = input_ids_shape[1]; + auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); + auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); - const int token_num_data = cpu_token_num.data()[0]; - auto x_remove_padding = paddle::full( - {token_num_data}, 0, paddle::DataType::INT64, input_ids.place()); - auto batch_id_per_token = paddle::full( - {token_num_data}, 0, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_q = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_k = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - int r = baidu::xpu::api::plugin::get_padding_offset( - xpu_ctx->x_context(), - batch_id_per_token.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - x_remove_padding.data(), - input_ids.data(), - cum_offsets.data(), - seq_len.data(), - seq_length, - bsz); - PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); - return {x_remove_padding, - cum_offsets_out, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k}; + const int token_num_data = cpu_token_num.data()[0]; + auto x_remove_padding = paddle::full( + {token_num_data}, 0, paddle::DataType::INT64, input_ids.place()); + auto batch_id_per_token = paddle::full( + {token_num_data}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + int r = baidu::xpu::api::plugin::get_padding_offset( + xpu_ctx->x_context(), + batch_id_per_token.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + x_remove_padding.data(), + input_ids.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed."); + return {x_remove_padding, + cum_offsets_out, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k}; } std::vector> GetPaddingOffsetInferShape( @@ -65,9 +64,9 @@ std::vector> GetPaddingOffsetInferShape( const std::vector &cum_offsets_shape, const std::vector &token_num_shape, const std::vector &seq_len_shape) { - int64_t bsz = seq_len_shape[0]; - int64_t seq_len = input_ids_shape[1]; - return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; + int64_t bsz = seq_len_shape[0]; + int64_t seq_len = input_ids_shape[1]; + return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; } std::vector GetPaddingOffsetInferDtype( @@ -75,11 +74,11 @@ std::vector GetPaddingOffsetInferDtype( const paddle::DataType &cum_offsets_dtype, const paddle::DataType &token_num_dtype, const paddle::DataType &seq_len_dtype) { - return {input_ids_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype}; + return {input_ids_dtype, + seq_len_dtype, + seq_len_dtype, + seq_len_dtype, + seq_len_dtype}; } PD_BUILD_OP(get_padding_offset) diff --git a/custom_ops/xpu_ops/src/ops/get_token_penalty_multi_scores.cc b/custom_ops/xpu_ops/src/ops/get_token_penalty_multi_scores.cc index 95b5b7977..6e579e4b8 100644 --- a/custom_ops/xpu_ops/src/ops/get_token_penalty_multi_scores.cc +++ b/custom_ops/xpu_ops/src/ops/get_token_penalty_multi_scores.cc @@ -12,70 +12,97 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "paddle/extension.h" #include "paddle/phi/core/enforce.h" #include "xpu/plugin.h" -#include -void TokenPenaltyMultiScores( - const paddle::Tensor &pre_ids, const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_scores, - const paddle::Tensor &presence_scores, const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - int64_t bs = logits.shape()[0]; - PADDLE_ENFORCE_LE( - bs, 640, - phi::errors::InvalidArgument( - "Only support bsz <= 1024, but received bsz is %d", bs)); - int64_t length = logits.shape()[1]; - int64_t length_id = pre_ids.shape()[1]; - int64_t length_bad_words = bad_tokens.shape()[0]; - int64_t end_length = eos_token_id.shape()[0]; - switch (logits.type()) { +void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, + const paddle::Tensor &logits, + const paddle::Tensor &penalty_scores, + const paddle::Tensor &frequency_scores, + const paddle::Tensor &presence_scores, + const paddle::Tensor &temperatures, + const paddle::Tensor &bad_tokens, + const paddle::Tensor &cur_len, + const paddle::Tensor &min_len, + const paddle::Tensor &eos_token_id) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + int64_t bs = logits.shape()[0]; + PADDLE_ENFORCE_LE( + bs, + 640, + phi::errors::InvalidArgument( + "Only support bsz <= 1024, but received bsz is %d", bs)); + int64_t length = logits.shape()[1]; + int64_t length_id = pre_ids.shape()[1]; + int64_t length_bad_words = bad_tokens.shape()[0]; + int64_t end_length = eos_token_id.shape()[0]; + switch (logits.type()) { case paddle::DataType::FLOAT16: { - using XPUType = typename XPUTypeTrait::Type; - typedef paddle::float16 data_t; - int r = baidu::xpu::api::plugin::token_penalty_multi_scores( - xpu_ctx->x_context(), pre_ids.data(), - reinterpret_cast( - const_cast(logits.data())), - reinterpret_cast(penalty_scores.data()), - reinterpret_cast(frequency_scores.data()), - reinterpret_cast(presence_scores.data()), - temperatures.data(), cur_len.data(), - min_len.data(), eos_token_id.data(), - bad_tokens.data(), bs, length, length_id, end_length, - length_bad_words); - PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); + using XPUType = typename XPUTypeTrait::Type; + typedef paddle::float16 data_t; + int r = baidu::xpu::api::plugin::token_penalty_multi_scores( + xpu_ctx->x_context(), + pre_ids.data(), + reinterpret_cast( + const_cast(logits.data())), + reinterpret_cast(penalty_scores.data()), + reinterpret_cast(frequency_scores.data()), + reinterpret_cast(presence_scores.data()), + temperatures.data(), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bad_tokens.data(), + bs, + length, + length_id, + end_length, + length_bad_words); + PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); } break; case paddle::DataType::FLOAT32: { - int r = baidu::xpu::api::plugin::token_penalty_multi_scores( - xpu_ctx->x_context(), pre_ids.data(), - const_cast(logits.data()), - penalty_scores.data(), frequency_scores.data(), - presence_scores.data(), temperatures.data(), - cur_len.data(), min_len.data(), - eos_token_id.data(), bad_tokens.data(), bs, - length, length_id, end_length, length_bad_words); - PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); + int r = baidu::xpu::api::plugin::token_penalty_multi_scores( + xpu_ctx->x_context(), + pre_ids.data(), + const_cast(logits.data()), + penalty_scores.data(), + frequency_scores.data(), + presence_scores.data(), + temperatures.data(), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bad_tokens.data(), + bs, + length, + length_id, + end_length, + length_bad_words); + PD_CHECK(r == 0, "xpu::plugin::token_penalty_multi_scores failed."); } break; default: - PD_THROW("NOT supported data type. " - "Only float16 and float32 are supported. "); - break; - } + PD_THROW( + "NOT supported data type. " + "Only float16 and float32 are supported. "); + break; + } } PD_BUILD_OP(get_token_penalty_multi_scores) - .Inputs({"pre_ids", "logits", "penalty_scores", "frequency_scores", - "presence_scores", "temperatures", "bad_tokens", "cur_len", - "min_len", "eos_token_id"}) + .Inputs({"pre_ids", + "logits", + "penalty_scores", + "frequency_scores", + "presence_scores", + "temperatures", + "bad_tokens", + "cur_len", + "min_len", + "eos_token_id"}) .Outputs({"logits_out"}) .SetInplaceMap({{"logits", "logits_out"}}) .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores)); diff --git a/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc b/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc index 98e0e6648..0b064044b 100644 --- a/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc +++ b/custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc @@ -72,7 +72,7 @@ void MoeExpertFFNImpl(xftblock::Tensor* ffn_in, is_padding_input ? token_num_info : nullptr, expert_num, 1, // moe_topk - 0, // group_size + 0, // group_size ffn1_out_shape.size() == 2 ? xftblock::MoeFCInputMode::DENSE : xftblock::MoeFCInputMode::SPARSE); PD_CHECK(ret == 0); diff --git a/custom_ops/xpu_ops/src/ops/moe_layer.cc b/custom_ops/xpu_ops/src/ops/moe_layer.cc index 937580d2c..2c948ffd7 100644 --- a/custom_ops/xpu_ops/src/ops/moe_layer.cc +++ b/custom_ops/xpu_ops/src/ops/moe_layer.cc @@ -29,210 +29,246 @@ namespace xftblock = baidu::xpu::xftblock; namespace api = baidu::xpu::api; -template struct fused_moe_ffn_trait { - using GEMM_TYPE = TW; +template +struct fused_moe_ffn_trait { + using GEMM_TYPE = TW; }; -template <> struct fused_moe_ffn_trait { - using GEMM_TYPE = float; +template <> +struct fused_moe_ffn_trait { + using GEMM_TYPE = float; }; -template <> struct fused_moe_ffn_trait { - using GEMM_TYPE = float; +template <> +struct fused_moe_ffn_trait { + using GEMM_TYPE = float; }; -template <> struct fused_moe_ffn_trait { - using GEMM_TYPE = int4_wo_int15; +template <> +struct fused_moe_ffn_trait { + using GEMM_TYPE = int4_wo_int15; }; template std::vector MoeLayerKernel( - const paddle::Tensor &x, const paddle::Tensor &gate_weight, + const paddle::Tensor &x, + const paddle::Tensor &gate_weight, const paddle::optional &gate_correction_bias, - const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, + const paddle::Tensor &up_gate_proj_weight, + const paddle::Tensor &down_proj_weight, const paddle::optional &up_gate_proj_bias, const paddle::optional &down_proj_bias, const paddle::optional &up_gate_proj_weight_scale, const paddle::optional &down_proj_weight_scale, - const paddle::optional &down_proj_in_scale, // not support - const std::string &quant_method, const int moe_top_k, + const paddle::optional &down_proj_in_scale, // not support + const std::string &quant_method, + const int moe_top_k, const bool moe_group) { - // std::cout << "[Op Debug] enter moe layer" << std::endl; - using XPU_TX = typename XPUTypeTrait::Type; - using XPU_TW = typename XPUTypeTrait::Type; - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr); - auto rt_guard = xctx.get_rt_guard(); + // std::cout << "[Op Debug] enter moe layer" << std::endl; + using XPU_TX = typename XPUTypeTrait::Type; + using XPU_TW = typename XPUTypeTrait::Type; + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr); + auto rt_guard = xctx.get_rt_guard(); + + const auto xtype = x.dtype(); + auto x_dims = x.shape(); + auto up_gate_proj_dims = up_gate_proj_weight.shape(); + PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2."); + PD_CHECK(up_gate_proj_dims.size() == 3, + "up_gate_proj_dims.size() should be 3."); + PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, + "down_proj_in_scale not support."); + if (quant_method == "weight_only_int4") { + PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2, + "x_dims[1] should equal to up_gate_proj_dims[2], (weight must be " + "[e,n,k])."); + } else { + PD_CHECK(x_dims[1] == up_gate_proj_dims[2], + "x_dims[1] should equal to up_gate_proj_dims[2], (weight must be " + "[e,n,k])."); + } + + int token_num = x_dims[0]; + int hidden_dim = x_dims[1]; + int expert_num = up_gate_proj_dims[0]; + int inter_dim = up_gate_proj_dims[1]; + int outer_dim = inter_dim / 2; + + paddle::Tensor fused_moe_out = paddle::empty_like(x); + + auto x_mpart_shape = x_dims; + int MPART_SIZE = 2048; + if (const char *env_val = std::getenv("XPU_MPART_SIZE")) { + MPART_SIZE = std::atoi(env_val); + } + int bsz = x_dims[0]; + for (int m_part_start = 0; m_part_start < bsz; m_part_start += MPART_SIZE) { + auto m_part_end = std::min(m_part_start + MPART_SIZE, bsz); + auto x_offset = m_part_start * hidden_dim; + x_mpart_shape[0] = m_part_end - m_part_start; + int ret = -1; + auto xftblock_tx = xftblock::DataTypeToEnum::value; + auto xftblock_tw = xftblock::DataTypeToEnum::value; + // input + output + xftblock::Tensor xin( + const_cast(x.data() + x_offset), xftblock_tx, x_mpart_shape); + + xftblock::Tensor xout(fused_moe_out.mutable_data() + x_offset, + xftblock_tx, + x_mpart_shape); + // gate + xftblock::Tensor xgate_w(const_cast(gate_weight.data()), + xftblock::DataType::DT_FLOAT, + gate_weight.shape()); + std::shared_ptr xgate_correct_bias; + if (gate_correction_bias.get_ptr()) { + xgate_correct_bias = std::make_shared( + const_cast(gate_correction_bias.get_ptr()->data()), + xftblock::DataType::DT_FLOAT, + gate_correction_bias.get_ptr()->shape()); + } + + // up_gate_proj + down_proj + std::shared_ptr xup_gate_proj_w, xdown_proj_w; + + if (std::is_same::value) { + xup_gate_proj_w = std::make_shared( + const_cast(up_gate_proj_weight.data()), + nullptr, + const_cast( + up_gate_proj_weight_scale.get_ptr() + ? up_gate_proj_weight_scale.get_ptr()->data() + : nullptr), + xftblock_tw, + std::vector{expert_num, inter_dim, hidden_dim}); + + xdown_proj_w = std::make_shared( + const_cast(down_proj_weight.data()), + nullptr, + const_cast( + down_proj_weight_scale.get_ptr() + ? down_proj_weight_scale.get_ptr()->data() + : nullptr), + xftblock_tw, + std::vector{expert_num, hidden_dim, outer_dim}); - const auto xtype = x.dtype(); - auto x_dims = x.shape(); - auto up_gate_proj_dims = up_gate_proj_weight.shape(); - PD_CHECK(x_dims.size() == 2, "x_dims.size() should be 2."); - PD_CHECK(up_gate_proj_dims.size() == 3, "up_gate_proj_dims.size() should be 3."); - PD_CHECK(down_proj_in_scale.get_ptr() == nullptr, "down_proj_in_scale not support."); - if (quant_method == "weight_only_int4") { - PD_CHECK(x_dims[1] == up_gate_proj_dims[2] * 2, - "x_dims[1] should equal to up_gate_proj_dims[2], (weight must be " - "[e,n,k])."); } else { - PD_CHECK(x_dims[1] == up_gate_proj_dims[2], - "x_dims[1] should equal to up_gate_proj_dims[2], (weight must be " - "[e,n,k])."); + xup_gate_proj_w = std::make_shared( + const_cast(up_gate_proj_weight.data()), + nullptr, + const_cast( + up_gate_proj_weight_scale.get_ptr() + ? up_gate_proj_weight_scale.get_ptr()->data() + : nullptr), + xftblock_tw, + std::vector{expert_num, inter_dim, hidden_dim}); + + xdown_proj_w = std::make_shared( + const_cast(down_proj_weight.data()), + nullptr, + const_cast( + down_proj_weight_scale.get_ptr() + ? down_proj_weight_scale.get_ptr()->data() + : nullptr), + xftblock_tw, + std::vector{expert_num, hidden_dim, outer_dim}); } - - int token_num = x_dims[0]; - int hidden_dim = x_dims[1]; - int expert_num = up_gate_proj_dims[0]; - int inter_dim = up_gate_proj_dims[1]; - int outer_dim = inter_dim / 2; - - paddle::Tensor fused_moe_out = paddle::empty_like(x); - - auto x_mpart_shape = x_dims; - int MPART_SIZE = 2048; - if (const char* env_val = std::getenv("XPU_MPART_SIZE")) { - MPART_SIZE = std::atoi(env_val); + std::shared_ptr xup_gate_proj_bias; + std::shared_ptr xdown_proj_bias; + if (up_gate_proj_bias.get_ptr()) { + xup_gate_proj_bias = std::make_shared( + const_cast(up_gate_proj_bias.get_ptr()->data()), + xftblock::DataType::DT_FLOAT, + up_gate_proj_bias.get_ptr()->shape()); } - int bsz = x_dims[0]; - for (int m_part_start = 0; m_part_start < bsz; m_part_start += MPART_SIZE) { - auto m_part_end = std::min(m_part_start + MPART_SIZE, bsz); - auto x_offset = m_part_start * hidden_dim; - x_mpart_shape[0] = m_part_end - m_part_start; - int ret = -1; - auto xftblock_tx = xftblock::DataTypeToEnum::value; - auto xftblock_tw = xftblock::DataTypeToEnum::value; - // input + output - xftblock::Tensor xin(const_cast(x.data() + x_offset), xftblock_tx, - x_mpart_shape); - - xftblock::Tensor xout(fused_moe_out.mutable_data() + x_offset, xftblock_tx, - x_mpart_shape); - // gate - xftblock::Tensor xgate_w(const_cast(gate_weight.data()), - xftblock::DataType::DT_FLOAT, gate_weight.shape()); - std::shared_ptr xgate_correct_bias; - if (gate_correction_bias.get_ptr()) { - xgate_correct_bias = std::make_shared( - const_cast(gate_correction_bias.get_ptr()->data()), - xftblock::DataType::DT_FLOAT, - gate_correction_bias.get_ptr()->shape()); - } - - // up_gate_proj + down_proj - std::shared_ptr xup_gate_proj_w, xdown_proj_w; - - if (std::is_same::value) { - xup_gate_proj_w = std::make_shared( - const_cast(up_gate_proj_weight.data()), nullptr, - const_cast(up_gate_proj_weight_scale.get_ptr() - ? up_gate_proj_weight_scale.get_ptr()->data() - : nullptr), - xftblock_tw, - std::vector{expert_num, inter_dim, hidden_dim}); - - xdown_proj_w = std::make_shared( - const_cast(down_proj_weight.data()), nullptr, - const_cast(down_proj_weight_scale.get_ptr() - ? down_proj_weight_scale.get_ptr()->data() - : nullptr), - xftblock_tw, - std::vector{expert_num, hidden_dim, outer_dim}); - - } else { - xup_gate_proj_w = std::make_shared( - const_cast(up_gate_proj_weight.data()), nullptr, - const_cast(up_gate_proj_weight_scale.get_ptr() - ? up_gate_proj_weight_scale.get_ptr()->data() - : nullptr), - xftblock_tw, - std::vector{expert_num, inter_dim, hidden_dim} - ); - - xdown_proj_w = std::make_shared( - const_cast(down_proj_weight.data()), nullptr, - const_cast(down_proj_weight_scale.get_ptr() - ? down_proj_weight_scale.get_ptr()->data() - : nullptr), - xftblock_tw, - std::vector{expert_num, hidden_dim, outer_dim} - ); - } - std::shared_ptr xup_gate_proj_bias; - std::shared_ptr xdown_proj_bias; - if (up_gate_proj_bias.get_ptr()) { - xup_gate_proj_bias = std::make_shared( - const_cast(up_gate_proj_bias.get_ptr()->data()), - xftblock::DataType::DT_FLOAT, up_gate_proj_bias.get_ptr()->shape()); - } - if (down_proj_bias.get_ptr()) { - xdown_proj_bias = std::make_shared( - const_cast(down_proj_bias.get_ptr()->data()), - xftblock::DataType::DT_FLOAT, down_proj_bias.get_ptr()->shape()); - } - // std::cout << "[Op Debug] start init moe_ffn weight and bias" << - // std::endl; MoeFFNWeight - xftblock::MoeFFNWeight moe_ffn_w_struct; - moe_ffn_w_struct.gate_weight = &xgate_w; - moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get(); - moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get(); - moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get(); - moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get(); - moe_ffn_w_struct.score_bias = xgate_correct_bias.get(); - // MoeFFNParam - xftblock::MoeFFNParam moe_ffn_param; - moe_ffn_param.expert_num = expert_num; - moe_ffn_param.moe_top_k = moe_top_k; - moe_ffn_param.fast_swiglu = true; - - // std::cout << "[Op Debug] pre in xvfblock moe_ffn" << std::endl; - - using XPU_TGEMM = typename fused_moe_ffn_trait::GEMM_TYPE; - ret = baidu::xpu::xftblock::moe_ffn_block_sorted_castte_per_token< - XPU_TX, XPU_TW, XPU_TX, XPU_TGEMM>(&xctx, &xin, &xout, moe_ffn_w_struct, - moe_ffn_param); - PD_CHECK(ret == 0, - "xftblock::moe_ffn_block_sorted_castte_per_token failed"); + if (down_proj_bias.get_ptr()) { + xdown_proj_bias = std::make_shared( + const_cast(down_proj_bias.get_ptr()->data()), + xftblock::DataType::DT_FLOAT, + down_proj_bias.get_ptr()->shape()); } + // std::cout << "[Op Debug] start init moe_ffn weight and bias" << + // std::endl; MoeFFNWeight + xftblock::MoeFFNWeight moe_ffn_w_struct; + moe_ffn_w_struct.gate_weight = &xgate_w; + moe_ffn_w_struct.ffn_inter_weights = xup_gate_proj_w.get(); + moe_ffn_w_struct.ffn_inter_bias = xup_gate_proj_bias.get(); + moe_ffn_w_struct.ffn_outer_weights = xdown_proj_w.get(); + moe_ffn_w_struct.ffn_outer_bias = xdown_proj_bias.get(); + moe_ffn_w_struct.score_bias = xgate_correct_bias.get(); + // MoeFFNParam + xftblock::MoeFFNParam moe_ffn_param; + moe_ffn_param.expert_num = expert_num; + moe_ffn_param.moe_top_k = moe_top_k; + moe_ffn_param.fast_swiglu = true; - return {fused_moe_out}; + // std::cout << "[Op Debug] pre in xvfblock moe_ffn" << std::endl; + + using XPU_TGEMM = typename fused_moe_ffn_trait::GEMM_TYPE; + ret = + baidu::xpu::xftblock::moe_ffn_block_sorted_castte_per_token( + &xctx, &xin, &xout, moe_ffn_w_struct, moe_ffn_param); + PD_CHECK(ret == 0, + "xftblock::moe_ffn_block_sorted_castte_per_token failed"); + } + + return {fused_moe_out}; } -std::vector -MoeLayer(const paddle::Tensor &x, const paddle::Tensor &gate_weight, - const paddle::optional &gate_correction_bias, - const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, - const paddle::optional &up_gate_proj_bias, - const paddle::optional &down_proj_bias, - const paddle::optional &up_gate_proj_weight_scale, - const paddle::optional &down_proj_weight_scale, - const paddle::optional &down_proj_in_scale, - const std::string &quant_method, const int moe_top_k, - const bool moe_group) { - const auto x_type = x.dtype(); - const auto w_type = up_gate_proj_weight.dtype(); +std::vector MoeLayer( + const paddle::Tensor &x, + const paddle::Tensor &gate_weight, + const paddle::optional &gate_correction_bias, + const paddle::Tensor &up_gate_proj_weight, + const paddle::Tensor &down_proj_weight, + const paddle::optional &up_gate_proj_bias, + const paddle::optional &down_proj_bias, + const paddle::optional &up_gate_proj_weight_scale, + const paddle::optional &down_proj_weight_scale, + const paddle::optional &down_proj_in_scale, + const std::string &quant_method, + const int moe_top_k, + const bool moe_group) { + const auto x_type = x.dtype(); + const auto w_type = up_gate_proj_weight.dtype(); -#define APPLY_MOE_LAYER_KERNEL(TX, TW) \ - return MoeLayerKernel( \ - x, gate_weight, gate_correction_bias, up_gate_proj_weight, down_proj_weight, \ - up_gate_proj_bias, down_proj_bias, up_gate_proj_weight_scale, down_proj_weight_scale, \ - down_proj_in_scale, quant_method, moe_top_k, moe_group); +#define APPLY_MOE_LAYER_KERNEL(TX, TW) \ + return MoeLayerKernel(x, \ + gate_weight, \ + gate_correction_bias, \ + up_gate_proj_weight, \ + down_proj_weight, \ + up_gate_proj_bias, \ + down_proj_bias, \ + up_gate_proj_weight_scale, \ + down_proj_weight_scale, \ + down_proj_in_scale, \ + quant_method, \ + moe_top_k, \ + moe_group); - // TODO(mayang02): how to use quant_method? - if (x_type == paddle::DataType::BFLOAT16 && - w_type == paddle::DataType::BFLOAT16) { - APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, paddle::bfloat16); - } else if (x_type == paddle::DataType::BFLOAT16 && - quant_method == "weight_only_int8") { - APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int8_t); - } else if (x_type == paddle::DataType::BFLOAT16 && - quant_method == "weight_only_int4") { - APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t); - } else { - PD_THROW("MoeLayer not support x_type=", static_cast(x_type), - ", w_type=", static_cast(w_type), - ", quant_method=", quant_method); - return {}; - } + // TODO(mayang02): how to use quant_method? + if (x_type == paddle::DataType::BFLOAT16 && + w_type == paddle::DataType::BFLOAT16) { + APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, paddle::bfloat16); + } else if (x_type == paddle::DataType::BFLOAT16 && + quant_method == "weight_only_int8") { + APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int8_t); + } else if (x_type == paddle::DataType::BFLOAT16 && + quant_method == "weight_only_int4") { + APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t); + } else { + PD_THROW("MoeLayer not support x_type=", + static_cast(x_type), + ", w_type=", + static_cast(w_type), + ", quant_method=", + quant_method); + return {}; + } #undef APPLY_MOE_LAYER_KERNEL } @@ -244,14 +280,16 @@ std::vector> MoeLayerInferShape( const std::vector &down_proj_weight_shape, const paddle::optional> &up_gate_proj_bias_shape, const paddle::optional> &down_proj_bias_shape, - const paddle::optional> &up_gate_proj_weight_scale_shape, + const paddle::optional> + &up_gate_proj_weight_scale_shape, const paddle::optional> &down_proj_weight_scale_shape, const paddle::optional> &down_proj_in_scale_shape) { - return {x_shape}; + return {x_shape}; } std::vector MoeLayerInferDtype( - const paddle::DataType &x_dtype, const paddle::DataType &gate_weight_dtype, + const paddle::DataType &x_dtype, + const paddle::DataType &gate_weight_dtype, const paddle::optional &gate_correction_bias_dtype, const paddle::DataType &up_gate_proj_weight_dtype, const paddle::DataType &down_proj_weight_dtype, @@ -260,12 +298,16 @@ std::vector MoeLayerInferDtype( const paddle::optional &up_gate_proj_weight_scale_dtype, const paddle::optional &down_proj_weight_scale_dtype, const paddle::optional &down_proj_in_scale_dtype) { - return {x_dtype}; + return {x_dtype}; } -PD_BUILD_OP(xpu_moe_layer) // fused_moe - .Inputs({"x", "gate_weight", paddle::Optional("gate_correction_bias"), - "up_gate_proj_weight", "down_proj_weight", paddle::Optional("up_gate_proj_bias"), +PD_BUILD_OP(xpu_moe_layer) // fused_moe + .Inputs({"x", + "gate_weight", + paddle::Optional("gate_correction_bias"), + "up_gate_proj_weight", + "down_proj_weight", + paddle::Optional("up_gate_proj_bias"), paddle::Optional("down_proj_bias"), paddle::Optional("up_gate_proj_weight_scale"), paddle::Optional("down_proj_weight_scale"), diff --git a/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc b/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc index 500cfbf43..7d01430f2 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/alloc_cache_pinned.cc @@ -14,9 +14,9 @@ #include // NOLINT #include "cuda_runtime_api.h" // NOLINT +#include "ops/pybind/pybind.h" #include "paddle/extension.h" #include "xpu/runtime.h" -#include "ops/pybind/pybind.h" void check_xpu_error(int error) { if (error != XPU_SUCCESS) { diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 832bdbf69..288f43226 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -33,13 +33,13 @@ void prof_start(); void prof_stop(); -void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor, - const paddle::Tensor &seq_lens_this_time_tensor, - const paddle::Tensor &seq_lens_decoder_tensor, +void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor, + const paddle::Tensor& seq_lens_this_time_tensor, + const paddle::Tensor& seq_lens_decoder_tensor, const int rank, const int num_layers); -void GetOutputKVSignal(const paddle::Tensor &x, +void GetOutputKVSignal(const paddle::Tensor& x, int64_t rank_id, bool wait_flag); @@ -70,8 +70,8 @@ std::vector BlockAttn( const paddle::optional& smooth, const paddle::optional& kv_signal_data_cpu, const paddle::optional& cachekv_signal_thread_cpu, - const std::string &pos_emb_type="NORMAL", - bool rope_3d=false); + const std::string& pos_emb_type = "NORMAL", + bool rope_3d = false); std::vector MoERedundantTopKSelect( const paddle::Tensor& gating_logits, @@ -477,7 +477,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("bias"), py::arg("weight_dtype"), py::arg("arch"), - py::arg("group_size")=-1); + py::arg("group_size") = -1); m.def("ep_moe_expert_combine", &MoeEPCombine, diff --git a/custom_ops/xpu_ops/src/ops/recover_decode_task.cc b/custom_ops/xpu_ops/src/ops/recover_decode_task.cc index 34871f0d3..8151ea12b 100644 --- a/custom_ops/xpu_ops/src/ops/recover_decode_task.cc +++ b/custom_ops/xpu_ops/src/ops/recover_decode_task.cc @@ -18,32 +18,31 @@ #include "xpu/plugin.h" void RecoverDecodeTask(const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &block_tables, - const paddle::Tensor &is_block_step, - const int block_size) { -phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - const int bsz = seq_lens_this_time.shape()[0]; - const int block_num_per_seq = block_tables.shape()[1]; - int r = baidu::xpu::api::plugin::recover_decode_task( - xpu_ctx->x_context(), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_seq_lens_decoder.data()), - const_cast(block_tables.data()), - const_cast(is_block_step.data()), - bsz, - block_num_per_seq, - block_size); - PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed."); + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &block_tables, + const paddle::Tensor &is_block_step, + const int block_size) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + const int bsz = seq_lens_this_time.shape()[0]; + const int block_num_per_seq = block_tables.shape()[1]; + int r = baidu::xpu::api::plugin::recover_decode_task( + xpu_ctx->x_context(), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(is_block_step.data()), + bsz, + block_num_per_seq, + block_size); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed."); } PD_BUILD_OP(recover_decode_task) diff --git a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc index 79a86af6b..0d41fa0d5 100644 --- a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc +++ b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc @@ -74,8 +74,8 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data( using type_meta_data = RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data; - // std::printf("#### open_shm_and_get_complete_signal_meta_data layer idx:%d, - // to ptx:%p \n", + // std::printf("#### open_shm_and_get_complete_signal_meta_data layer + // idx:%d, to ptx:%p \n", // -1, signal_ptr); type_meta_data meta_data(-1, signal_ptr, signal_shm_fd); @@ -102,8 +102,8 @@ void RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise( int32_t layer_id = meta_data_ptr[0]; int32_t* ptr = reinterpret_cast(meta_data_ptr[1]); *ptr = layer_id; - // std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d, to - // ptx:%p \n", + // std::printf("#### save_cache_kv_complete_signal_layerwise layer idx:%d, + // to ptx:%p \n", // *ptr, meta_data_ptr[1]); } diff --git a/custom_ops/xpu_ops/src/ops/save_with_output_msg.cc b/custom_ops/xpu_ops/src/ops/save_with_output_msg.cc index fd132a775..7e1bb8815 100644 --- a/custom_ops/xpu_ops/src/ops/save_with_output_msg.cc +++ b/custom_ops/xpu_ops/src/ops/save_with_output_msg.cc @@ -12,114 +12,118 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" #include #include #include #include #include +#include "paddle/extension.h" #define MAX_BSZ 256 // #define SAVE_WITH_OUTPUT_DEBUG struct msgdata { - long mtype; - int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens + long mtype; + int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens }; // #define SAVE_WITH_OUTPUT_DEBUG -void SaveOutMmsg(const paddle::Tensor &x, const paddle::Tensor ¬_need_stop, - int64_t rank_id, int msg_queue_id, bool save_each_rank) { - if (!save_each_rank && rank_id > 0) { - return; - } - auto x_cpu = x.copy_to(paddle::CPUPlace(), false); - int64_t *x_data = x_cpu.data(); - static struct msgdata msg_sed; - - if (const char *inference_msg_queue_id_env_p = - std::getenv("INFERENCE_MSG_QUEUE_ID")) { - std::string inference_msg_queue_id_env_str( - inference_msg_queue_id_env_p); - int inference_msg_queue_id_from_env = - std::stoi(inference_msg_queue_id_env_str); - msg_queue_id = inference_msg_queue_id_from_env; -#ifdef SAVE_WITH_OUTPUT_DEBUG - std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " - << inference_msg_queue_id_from_env << std::endl; -#endif - } else { -#ifdef SAVE_WITH_OUTPUT_DEBUG - std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." - << std::endl; -#endif - } - int inference_msg_id_from_env = 1; - if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { - std::string inference_msg_id_env_str(inference_msg_id_env_p); - inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); - if (inference_msg_id_from_env == 2) { - // 2 and -2 is preserve for no-output indication. - throw std::runtime_error( - " INFERENCE_MSG_ID cannot be 2, please use other number."); - } - if (inference_msg_id_from_env < 0) { - throw std::runtime_error( - " INFERENCE_MSG_ID cannot be negative, please use other " - "number."); - } - -#ifdef SAVE_WITH_OUTPUT_DEBUG - std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env - << std::endl; -#endif - } else { -#ifdef SAVE_WITH_OUTPUT_DEBUG - std::cout - << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." - << std::endl; -#endif - } - static key_t key = ftok("/dev/shm", msg_queue_id); - - static int msgid = msgget(key, IPC_CREAT | 0666); -#ifdef SAVE_WITH_OUTPUT_DEBUG - std::cout << "save_output key: " << key << std::endl; - std::cout << "save_output msgid: " << msgid << std::endl; -#endif - msg_sed.mtype = 1; - bool not_need_stop_data = not_need_stop.data()[0]; - // printf("not_need_stop_data %d\n", (int)not_need_stop_data); - msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env - : -inference_msg_id_from_env; - int bsz = x.shape()[0]; - msg_sed.mtext[1] = bsz; - for (int i = 2; i < bsz + 2; i++) { - msg_sed.mtext[i] = (int)x_data[i - 2]; - } -#ifdef SAVE_WITH_OUTPUT_DEBUG - std::cout << "save_output msg data: "; - for (int i = 0; i < bsz; i++) { - std::cout << " " << (int)x_data[i]; - } - std::cout << std::endl; -#endif - if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) { - printf("save_output full msg buffer\n"); - } +void SaveOutMmsg(const paddle::Tensor &x, + const paddle::Tensor ¬_need_stop, + int64_t rank_id, + int msg_queue_id, + bool save_each_rank) { + if (!save_each_rank && rank_id > 0) { return; + } + auto x_cpu = x.copy_to(paddle::CPUPlace(), false); + int64_t *x_data = x_cpu.data(); + static struct msgdata msg_sed; + + if (const char *inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); + msg_queue_id = inference_msg_queue_id_from_env; +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + } else { +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." + << std::endl; +#endif + } + int inference_msg_id_from_env = 1; + if (const char *inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { + std::string inference_msg_id_env_str(inference_msg_id_env_p); + inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); + if (inference_msg_id_from_env == 2) { + // 2 and -2 is preserve for no-output indication. + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be 2, please use other number."); + } + if (inference_msg_id_from_env < 0) { + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be negative, please use other " + "number."); + } + +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env + << std::endl; +#endif + } else { +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." + << std::endl; +#endif + } + static key_t key = ftok("/dev/shm", msg_queue_id); + + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output key: " << key << std::endl; + std::cout << "save_output msgid: " << msgid << std::endl; +#endif + msg_sed.mtype = 1; + bool not_need_stop_data = not_need_stop.data()[0]; + // printf("not_need_stop_data %d\n", (int)not_need_stop_data); + msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env + : -inference_msg_id_from_env; + int bsz = x.shape()[0]; + msg_sed.mtext[1] = bsz; + for (int i = 2; i < bsz + 2; i++) { + msg_sed.mtext[i] = (int)x_data[i - 2]; + } +#ifdef SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output msg data: "; + for (int i = 0; i < bsz; i++) { + std::cout << " " << (int)x_data[i]; + } + std::cout << std::endl; +#endif + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) { + printf("save_output full msg buffer\n"); + } + return; } void SaveOutMmsgStatic(const paddle::Tensor &x, - const paddle::Tensor ¬_need_stop, int64_t rank_id, + const paddle::Tensor ¬_need_stop, + int64_t rank_id, bool save_each_rank) { - SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank); + SaveOutMmsg(x, not_need_stop, rank_id, 1, save_each_rank); } void SaveOutMmsgDynamic(const paddle::Tensor &x, - const paddle::Tensor ¬_need_stop, int64_t rank_id, - int msg_queue_id, bool save_each_rank) { - SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank); + const paddle::Tensor ¬_need_stop, + int64_t rank_id, + int msg_queue_id, + bool save_each_rank) { + SaveOutMmsg(x, not_need_stop, rank_id, msg_queue_id, save_each_rank); } PD_BUILD_OP(save_output) diff --git a/custom_ops/xpu_ops/src/ops/set_value_by_flags_and_idx.cc b/custom_ops/xpu_ops/src/ops/set_value_by_flags_and_idx.cc index e060f12a3..7d37fb2a7 100644 --- a/custom_ops/xpu_ops/src/ops/set_value_by_flags_and_idx.cc +++ b/custom_ops/xpu_ops/src/ops/set_value_by_flags_and_idx.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "paddle/extension.h" #include "xpu/plugin.h" -#include void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &input_ids, @@ -23,26 +23,35 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - std::vector pre_ids_all_shape = pre_ids_all.shape(); - int bs = seq_lens_this_time.shape()[0]; - int length = pre_ids_all.shape()[1]; - int length_input_ids = input_ids.shape()[1]; - int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx( - xpu_ctx->x_context(), stop_flags.data(), - const_cast(pre_ids_all.data()), - input_ids.data(), seq_lens_encoder.data(), - seq_lens_decoder.data(), step_idx.data(), bs, length, - length_input_ids); - PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed."); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + std::vector pre_ids_all_shape = pre_ids_all.shape(); + int bs = seq_lens_this_time.shape()[0]; + int length = pre_ids_all.shape()[1]; + int length_input_ids = input_ids.shape()[1]; + int r = baidu::xpu::api::plugin::set_value_by_flags_and_idx( + xpu_ctx->x_context(), + stop_flags.data(), + const_cast(pre_ids_all.data()), + input_ids.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + step_idx.data(), + bs, + length, + length_input_ids); + PD_CHECK(r == 0, "xpu::plugin::set_value_by_flags_and_idx failed."); } PD_BUILD_OP(set_value_by_flags_and_idx) - .Inputs({"pre_ids_all", "input_ids", "seq_lens_this_time", - "seq_lens_encoder", "seq_lens_decoder", "step_idx", "stop_flags"}) + .Inputs({"pre_ids_all", + "input_ids", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", + "step_idx", + "stop_flags"}) .Outputs({"pre_ids_all_out"}) .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}}) .SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdx)); diff --git a/custom_ops/xpu_ops/src/ops/share_external_data.cc b/custom_ops/xpu_ops/src/ops/share_external_data.cc index f4f40bebb..04f9f2ed9 100644 --- a/custom_ops/xpu_ops/src/ops/share_external_data.cc +++ b/custom_ops/xpu_ops/src/ops/share_external_data.cc @@ -30,8 +30,9 @@ std::vector ShareExternalData(const paddle::Tensor &input, void *data_ptr_addr = nullptr; if (use_ipc) { #if XPURT_VERSION_MAJOR == 5 - int ret = xpu_ipc_open_memhandle( - &data_ptr_addr, *(XPUIpcMemHandle *)&shm->memHandle, 0x01); // NOLINT + int ret = xpu_ipc_open_memhandle(&data_ptr_addr, + *(XPUIpcMemHandle *)&shm->memHandle, + 0x01); // NOLINT PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed"); #elif XPURT_VERSION_MAJOR == 4 PD_THROW("kl2 not support prefix cache"); diff --git a/custom_ops/xpu_ops/src/ops/step.cc b/custom_ops/xpu_ops/src/ops/step.cc index 196679d75..029edf0d0 100644 --- a/custom_ops/xpu_ops/src/ops/step.cc +++ b/custom_ops/xpu_ops/src/ops/step.cc @@ -12,82 +12,100 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "paddle/extension.h" #include "paddle/phi/core/enforce.h" #include "xpu/plugin.h" -#include -void StepPaddle( - const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &ori_seq_lens_encoder, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] - const paddle::Tensor &encoder_block_lens, - const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list, - const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list, - const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list, - const paddle::Tensor &need_block_len, const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, const paddle::Tensor &free_list_len, - const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens, - const paddle::Tensor &first_token_ids, const int block_size, - const int encoder_decoder_block_num) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); +void StepPaddle(const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, + const int block_size, + const int encoder_decoder_block_num) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - const int bsz = seq_lens_this_time.shape()[0]; - PADDLE_ENFORCE_LE( - bsz, 640, - phi::errors::InvalidArgument( - "Only support bsz <= 640, but received bsz is %d", bsz)); - const int block_num_per_seq = block_tables.shape()[1]; - const int length = input_ids.shape()[1]; - const int pre_id_length = pre_ids.shape()[1]; - const int max_decoder_block_num = pre_id_length / block_size; - int r = baidu::xpu::api::plugin::free_and_dispatch_block( - xpu_ctx->x_context(), const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_decoder.data()), - const_cast(block_tables.data()), - const_cast(encoder_block_lens.data()), - const_cast(is_block_step.data()), - const_cast(step_block_list.data()), - const_cast(step_lens.data()), + const int bsz = seq_lens_this_time.shape()[0]; + PADDLE_ENFORCE_LE( + bsz, + 640, + phi::errors::InvalidArgument( + "Only support bsz <= 640, but received bsz is %d", bsz)); + const int block_num_per_seq = block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_decoder_block_num = pre_id_length / block_size; + int r = baidu::xpu::api::plugin::free_and_dispatch_block( + xpu_ctx->x_context(), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(first_token_ids.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num); + PD_CHECK(r == 0, "free_and_dispatch_block failed."); + auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); + int recover_lens_cpu_data = recover_lens_cpu.data()[0]; + if (recover_lens_cpu_data > 0) { + r = baidu::xpu::api::plugin::recover_block( + xpu_ctx->x_context(), const_cast(recover_block_list.data()), const_cast(recover_lens.data()), - const_cast(need_block_list.data()), - const_cast(need_block_len.data()), - const_cast(used_list_len.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + ori_seq_lens_encoder.data(), + const_cast(seq_lens_encoder.data()), + seq_lens_decoder.data(), + const_cast(block_tables.data()), const_cast(free_list.data()), const_cast(free_list_len.data()), - const_cast(first_token_ids.data()), bsz, block_size, - block_num_per_seq, max_decoder_block_num); - PD_CHECK(r == 0, "free_and_dispatch_block failed."); - auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); - int recover_lens_cpu_data = recover_lens_cpu.data()[0]; - if (recover_lens_cpu_data > 0) { - r = baidu::xpu::api::plugin::recover_block( - xpu_ctx->x_context(), - const_cast(recover_block_list.data()), - const_cast(recover_lens.data()), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - ori_seq_lens_encoder.data(), - const_cast(seq_lens_encoder.data()), - seq_lens_decoder.data(), - const_cast(block_tables.data()), - const_cast(free_list.data()), - const_cast(free_list_len.data()), - const_cast(input_ids.data()), - pre_ids.data(), step_idx.data(), - encoder_block_lens.data(), used_list_len.data(), - next_tokens.data(), first_token_ids.data(), bsz, - block_num_per_seq, length, pre_id_length); - PD_CHECK(r == 0, "recover_block failed."); - } + const_cast(input_ids.data()), + pre_ids.data(), + step_idx.data(), + encoder_block_lens.data(), + used_list_len.data(), + next_tokens.data(), + first_token_ids.data(), + bsz, + block_num_per_seq, + length, + pre_id_length); + PD_CHECK(r == 0, "recover_block failed."); + } } PD_BUILD_OP(step_paddle) @@ -114,13 +132,24 @@ PD_BUILD_OP(step_paddle) "next_tokens", "first_token_ids"}) .Attrs({"block_size: int", "encoder_decoder_block_num: int"}) - .Outputs({"stop_flags_out", "seq_lens_this_time_out", - "seq_lens_encoder_out", "seq_lens_decoder_out", - "block_tables_out", "encoder_block_lens_out", "is_block_step_out", - "step_block_list_out", "step_lens_out", "recover_block_list_out", - "recover_lens_out", "need_block_list_out", "need_block_len_out", - "used_list_len_out", "free_list_out", "free_list_len_out", - "input_ids_out", "first_token_ids_out"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", + "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out", + "first_token_ids_out"}) .SetInplaceMap({{"stop_flags", "stop_flags_out"}, {"seq_lens_this_time", "seq_lens_this_time_out"}, {"seq_lens_encoder", "seq_lens_encoder_out"}, diff --git a/custom_ops/xpu_ops/src/ops/stop_generation_multi_ends.cc b/custom_ops/xpu_ops/src/ops/stop_generation_multi_ends.cc index 72f7f9fb6..7043baa73 100644 --- a/custom_ops/xpu_ops/src/ops/stop_generation_multi_ends.cc +++ b/custom_ops/xpu_ops/src/ops/stop_generation_multi_ends.cc @@ -12,9 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" -#include "xpu/plugin.h" -#include #include #include #include @@ -24,6 +21,9 @@ #include #include #include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &stop_flags, @@ -31,22 +31,25 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &end_ids, const paddle::Tensor &next_tokens, const bool beam_search) { - PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); - PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - std::vector shape = topk_ids.shape(); - int64_t bs_now = shape[0]; - int64_t end_length = end_ids.shape()[0]; - int r = baidu::xpu::api::plugin::set_stop_value_multi_ends( - xpu_ctx->x_context(), const_cast(stop_flags.data()), - const_cast(topk_ids.data()), - const_cast(next_tokens.data()), - end_ids.data(), seq_lens.data(), bs_now, end_length, - beam_search); - PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed."); + PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); + PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + std::vector shape = topk_ids.shape(); + int64_t bs_now = shape[0]; + int64_t end_length = end_ids.shape()[0]; + int r = baidu::xpu::api::plugin::set_stop_value_multi_ends( + xpu_ctx->x_context(), + const_cast(stop_flags.data()), + const_cast(topk_ids.data()), + const_cast(next_tokens.data()), + end_ids.data(), + seq_lens.data(), + bs_now, + end_length, + beam_search); + PD_CHECK(r == 0, "xpu::plugin::set_stop_value_multi_ends failed."); } PD_BUILD_OP(set_stop_value_multi_ends) diff --git a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc index a702a465f..1df2ba82b 100644 --- a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc +++ b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc @@ -17,53 +17,49 @@ #include "paddle/extension.h" #include "xpu/plugin.h" -void TextImageGatherScatter( - paddle::Tensor& input, - paddle::Tensor& text_input, - paddle::Tensor& image_input, - paddle::Tensor& token_type_ids, - paddle::Tensor& text_index, - paddle::Tensor& image_index, - const bool is_scatter) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); +void TextImageGatherScatter(paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - const int64_t token_num = input.dims()[0]; - const int64_t hidden_size = input.dims()[1]; - const int64_t text_token_num = text_input.dims()[0]; - const int64_t image_token_num = image_input.dims()[0]; + const int64_t token_num = input.dims()[0]; + const int64_t hidden_size = input.dims()[1]; + const int64_t text_token_num = text_input.dims()[0]; + const int64_t image_token_num = image_input.dims()[0]; - switch (input.type()) { - case paddle::DataType::BFLOAT16: { - using XPUType = typename XPUTypeTrait::Type; - typedef paddle::bfloat16 data_t; - int r = baidu::xpu::api::plugin::text_image_gather_scatter( - xpu_ctx->x_context(), - reinterpret_cast(input.data()), - reinterpret_cast(text_input.data()), - reinterpret_cast(image_input.data()), - reinterpret_cast(token_type_ids.data()), - reinterpret_cast(text_index.data()), - reinterpret_cast(image_index.data()), - token_num, - text_token_num, - image_token_num, - hidden_size, - is_scatter - ); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter"); - break; - } - default: { - PD_THROW( - "NOT supported data type. Only support BFLOAT16. "); - break; - } + switch (input.type()) { + case paddle::DataType::BFLOAT16: { + using XPUType = typename XPUTypeTrait::Type; + typedef paddle::bfloat16 data_t; + int r = baidu::xpu::api::plugin::text_image_gather_scatter( + xpu_ctx->x_context(), + reinterpret_cast(input.data()), + reinterpret_cast(text_input.data()), + reinterpret_cast(image_input.data()), + reinterpret_cast(token_type_ids.data()), + reinterpret_cast(text_index.data()), + reinterpret_cast(image_index.data()), + token_num, + text_token_num, + image_token_num, + hidden_size, + is_scatter); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter"); + break; } + default: { + PD_THROW("NOT supported data type. Only support BFLOAT16. "); + break; + } + } } - PD_BUILD_OP(text_image_gather_scatter) .Inputs({"input", "text_input", diff --git a/custom_ops/xpu_ops/src/ops/text_image_index_out.cc b/custom_ops/xpu_ops/src/ops/text_image_index_out.cc index a0ce15036..10515614a 100644 --- a/custom_ops/xpu_ops/src/ops/text_image_index_out.cc +++ b/custom_ops/xpu_ops/src/ops/text_image_index_out.cc @@ -16,33 +16,30 @@ #include "paddle/extension.h" #include "xpu/plugin.h" -void TextImageIndexOut( - const paddle::Tensor& token_type_ids, - const paddle::Tensor& text_index, - const paddle::Tensor& image_index) { - if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type() - != paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) { - PD_THROW("NOT supported data type. Only support BFLOAT16. "); - } - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - const int64_t token_num = token_type_ids.shape()[0]; - int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(), - token_type_ids.data(), - const_cast(text_index.data()), - const_cast(image_index.data()), - token_num); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out"); +void TextImageIndexOut(const paddle::Tensor& token_type_ids, + const paddle::Tensor& text_index, + const paddle::Tensor& image_index) { + if (token_type_ids.type() != paddle::DataType::INT32 || + text_index.type() != paddle::DataType::INT32 || + image_index.type() != paddle::DataType::INT32) { + PD_THROW("NOT supported data type. Only support BFLOAT16. "); + } + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + const int64_t token_num = token_type_ids.shape()[0]; + int r = baidu::xpu::api::plugin::text_image_index_out( + xpu_ctx->x_context(), + token_type_ids.data(), + const_cast(text_index.data()), + const_cast(image_index.data()), + token_num); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out"); } - PD_BUILD_OP(text_image_index_out) - .Inputs({"token_type_ids", - "text_index", - "image_index"}) - .Outputs({"text_index_out", - "image_index_out"}) + .Inputs({"token_type_ids", "text_index", "image_index"}) + .Outputs({"text_index_out", "image_index_out"}) .SetInplaceMap({{"text_index", "text_index_out"}, {"image_index", "image_index_out"}}) .SetKernelFn(PD_KERNEL(TextImageIndexOut)); diff --git a/custom_ops/xpu_ops/src/ops/update_inputs.cc b/custom_ops/xpu_ops/src/ops/update_inputs.cc index 77b6cf9a1..53b057e30 100644 --- a/custom_ops/xpu_ops/src/ops/update_inputs.cc +++ b/custom_ops/xpu_ops/src/ops/update_inputs.cc @@ -26,40 +26,39 @@ void UpdateInputes(const paddle::Tensor &stop_flags, const paddle::Tensor &stop_nums, const paddle::Tensor &next_tokens, const paddle::Tensor &is_block_step) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - const int max_bsz = stop_flags.shape()[0]; - PADDLE_ENFORCE_LE( - max_bsz, - 1024, - phi::errors::InvalidArgument( - "Only support max_bs <= 1024, but received max_bs is %d", max_bsz)); - const int now_bsz = seq_lens_this_time.shape()[0]; - const int input_ids_stride = input_ids.shape()[1]; - auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false); + const int max_bsz = stop_flags.shape()[0]; + PADDLE_ENFORCE_LE( + max_bsz, + 1024, + phi::errors::InvalidArgument( + "Only support max_bs <= 1024, but received max_bs is %d", max_bsz)); + const int now_bsz = seq_lens_this_time.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + auto not_need_stop_xpu = not_need_stop.copy_to(stop_flags.place(), false); - int r = baidu::xpu::api::plugin::update_inputs( - xpu_ctx->x_context(), - const_cast(not_need_stop_xpu.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(input_ids.data()), - stop_nums.data(), - stop_flags.data(), - is_block_step.data(), - next_tokens.data(), - now_bsz, - max_bsz, - input_ids_stride); - PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed."); - auto not_need_stop_cpu = - not_need_stop_xpu.copy_to(not_need_stop.place(), false); - bool *not_need_stop_data = const_cast(not_need_stop.data()); - not_need_stop_data[0] = not_need_stop_cpu.data()[0]; + int r = baidu::xpu::api::plugin::update_inputs( + xpu_ctx->x_context(), + const_cast(not_need_stop_xpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), + stop_nums.data(), + stop_flags.data(), + is_block_step.data(), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs failed."); + auto not_need_stop_cpu = + not_need_stop_xpu.copy_to(not_need_stop.place(), false); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } PD_BUILD_OP(update_inputs) diff --git a/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc b/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc index 50dc8d748..9e77e636f 100644 --- a/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc +++ b/custom_ops/xpu_ops/src/ops/update_inputs_v1.cc @@ -18,55 +18,54 @@ #include "xpu/plugin.h" void UpdateInputesV1(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_seq_lens_decoder, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &topk_ids, - const paddle::Tensor &input_ids, - const paddle::Tensor &block_tables, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step, - const int block_size) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + const paddle::Tensor ¬_need_stop, // only on cpu + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &prompt_lens, + const paddle::Tensor &topk_ids, + const paddle::Tensor &input_ids, + const paddle::Tensor &block_tables, + const paddle::Tensor &stop_nums, + const paddle::Tensor &next_tokens, + const paddle::Tensor &is_block_step, + const int block_size) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - const int max_bsz = stop_flags.shape()[0]; - const int now_bsz = seq_lens_this_time.shape()[0]; - // std::cout << "now_bsz: " << now_bsz << std::endl; - const int input_ids_stride = input_ids.shape()[1]; - const int block_num_per_seq = block_tables.shape()[1]; - auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); - int r = baidu::xpu::api::plugin::update_inputs_v1( - xpu_ctx->x_context(), - const_cast(not_need_stop_gpu.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_seq_lens_decoder.data()), - const_cast(prompt_lens.data()), - const_cast(topk_ids.data()), - const_cast(input_ids.data()), - const_cast(block_tables.data()), - stop_nums.data(), - const_cast(stop_flags.data()), - const_cast(is_block_step.data()), - next_tokens.data(), - now_bsz, - max_bsz, - input_ids_stride, - block_num_per_seq, - block_size); - PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed."); - auto not_need_stop_cpu = - not_need_stop_gpu.copy_to(not_need_stop.place(), false); - bool *not_need_stop_data = const_cast(not_need_stop.data()); - not_need_stop_data[0] = not_need_stop_cpu.data()[0]; + const int max_bsz = stop_flags.shape()[0]; + const int now_bsz = seq_lens_this_time.shape()[0]; + // std::cout << "now_bsz: " << now_bsz << std::endl; + const int input_ids_stride = input_ids.shape()[1]; + const int block_num_per_seq = block_tables.shape()[1]; + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + int r = baidu::xpu::api::plugin::update_inputs_v1( + xpu_ctx->x_context(), + const_cast(not_need_stop_gpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(prompt_lens.data()), + const_cast(topk_ids.data()), + const_cast(input_ids.data()), + const_cast(block_tables.data()), + stop_nums.data(), + const_cast(stop_flags.data()), + const_cast(is_block_step.data()), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size); + PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_kernel_v1 failed."); + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), false); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; } PD_BUILD_OP(update_inputs_v1) diff --git a/custom_ops/xpu_ops/src/ops/utility/debug.h b/custom_ops/xpu_ops/src/ops/utility/debug.h old mode 100755 new mode 100644 diff --git a/custom_ops/xpu_ops/src/ops/utility/helper.h b/custom_ops/xpu_ops/src/ops/utility/helper.h index 85335682d..5f3e23689 100644 --- a/custom_ops/xpu_ops/src/ops/utility/helper.h +++ b/custom_ops/xpu_ops/src/ops/utility/helper.h @@ -15,8 +15,6 @@ #pragma once #include -#include -#include #include #include #include @@ -24,47 +22,56 @@ #include #include #include +#include +#include +#include #include "paddle/extension.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/dense_tensor.h" #include "xpu/plugin.h" -#include -template class PDTraits; +template +class PDTraits; -template <> class PDTraits { - public: - typedef float DataType; - typedef float data_t; +template <> +class PDTraits { + public: + typedef float DataType; + typedef float data_t; }; -template <> class PDTraits { - public: - typedef float16 DataType; - typedef paddle::float16 data_t; +template <> +class PDTraits { + public: + typedef float16 DataType; + typedef paddle::float16 data_t; }; -template <> class PDTraits { - public: - typedef bfloat16 DataType; - typedef paddle::bfloat16 data_t; +template <> +class PDTraits { + public: + typedef bfloat16 DataType; + typedef paddle::bfloat16 data_t; }; -template <> class PDTraits { - public: - typedef int8_t DataType; - typedef int8_t data_t; +template <> +class PDTraits { + public: + typedef int8_t DataType; + typedef int8_t data_t; }; -template <> class PDTraits { - public: - typedef uint8_t DataType; - typedef uint8_t data_t; +template <> +class PDTraits { + public: + typedef uint8_t DataType; + typedef uint8_t data_t; }; -template <> class PDTraits { - public: - typedef int64_t DataType; - typedef int64_t data_t; +template <> +class PDTraits { + public: + typedef int64_t DataType; + typedef int64_t data_t; }; diff --git a/custom_ops/xpu_ops/src/ops/weight_quantize_xpu.cc b/custom_ops/xpu_ops/src/ops/weight_quantize_xpu.cc index a660f9a77..4de303b80 100644 --- a/custom_ops/xpu_ops/src/ops/weight_quantize_xpu.cc +++ b/custom_ops/xpu_ops/src/ops/weight_quantize_xpu.cc @@ -11,110 +11,124 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" #include #include #include #include +#include "xpu/plugin.h" template -std::vector -WeightQuantizeKernel(const paddle::Tensor &x, const std::string &algo, - const int32_t arch, const int32_t group_size) { - using XPUType = typename XPUTypeTrait::Type; - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - int64_t k = x.shape()[0]; - int64_t n = x.shape()[1]; +std::vector WeightQuantizeKernel(const paddle::Tensor &x, + const std::string &algo, + const int32_t arch, + const int32_t group_size) { + using XPUType = typename XPUTypeTrait::Type; + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + int64_t k = x.shape()[0]; + int64_t n = x.shape()[1]; - paddle::Tensor scale = - paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place()); - if (algo == "weight_only_int8") { - paddle::Tensor out = - paddle::full({k, n}, 0, paddle::DataType::INT8, x.place()); - int ret = baidu::xpu::api::plugin::quant2d_per_channel( + paddle::Tensor scale = + paddle::full({n}, 0, paddle::DataType::FLOAT32, x.place()); + if (algo == "weight_only_int8") { + paddle::Tensor out = + paddle::full({k, n}, 0, paddle::DataType::INT8, x.place()); + int ret = + baidu::xpu::api::plugin::quant2d_per_channel( xpu_ctx->x_context(), - reinterpret_cast(x.template data()), nullptr, - out.data(), scale.data(), k, n); - PD_CHECK(ret == 0); - return {out, scale}; - } else if (algo == "weight_only_int4") { - // TODO(mayang02): fix quant2d_per_channel int4 bugs, use transpose + - // quant2d_per_token + transpose at now - PD_CHECK(k % 2 == 0); - paddle::Tensor out = paddle::full({(k + 1) / 2, n}, 0, - paddle::DataType::INT8, x.place()); - xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context()); - XPUType *x_trans = RAII_GUARD.alloc(k * n); - int8_t *out_trans = RAII_GUARD.alloc(k * n / 2); - PD_CHECK(x_trans != nullptr); - PD_CHECK(out_trans != nullptr); - int ret = baidu::xpu::api::transpose( - xpu_ctx->x_context(), - reinterpret_cast(x.data()), x_trans, {k, n}, - {1, 0}); - PD_CHECK(ret == 0); - ret = infer_ops::quant2d_per_token( - xpu_ctx->x_context(), x_trans, nullptr, - reinterpret_cast(out_trans), scale.data(), n, k); - PD_CHECK(ret == 0); - ret = baidu::xpu::api::transpose(xpu_ctx->x_context(), - out_trans, out.data(), - {n, k / 2}, {1, 0}); - PD_CHECK(ret == 0); - return {out, scale}; - } else { - PD_THROW("Weight quantize only supports weight_only_int8 on XPU now."); - return {}; - } + reinterpret_cast(x.template data()), + nullptr, + out.data(), + scale.data(), + k, + n); + PD_CHECK(ret == 0); + return {out, scale}; + } else if (algo == "weight_only_int4") { + // TODO(mayang02): fix quant2d_per_channel int4 bugs, use transpose + + // quant2d_per_token + transpose at now + PD_CHECK(k % 2 == 0); + paddle::Tensor out = + paddle::full({(k + 1) / 2, n}, 0, paddle::DataType::INT8, x.place()); + xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context()); + XPUType *x_trans = RAII_GUARD.alloc(k * n); + int8_t *out_trans = RAII_GUARD.alloc(k * n / 2); + PD_CHECK(x_trans != nullptr); + PD_CHECK(out_trans != nullptr); + int ret = baidu::xpu::api::transpose( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + x_trans, + {k, n}, + {1, 0}); + PD_CHECK(ret == 0); + ret = infer_ops::quant2d_per_token( + xpu_ctx->x_context(), + x_trans, + nullptr, + reinterpret_cast(out_trans), + scale.data(), + n, + k); + PD_CHECK(ret == 0); + ret = baidu::xpu::api::transpose(xpu_ctx->x_context(), + out_trans, + out.data(), + {n, k / 2}, + {1, 0}); + PD_CHECK(ret == 0); + return {out, scale}; + } else { + PD_THROW("Weight quantize only supports weight_only_int8 on XPU now."); + return {}; + } } std::vector WeightQuantize(const paddle::Tensor &x, const std::string &algo, const int32_t arch, const int32_t group_size) { - const auto x_type = x.dtype(); -#define APPLY_WEIGHT_QUANTIZE_KERNEL(TX) \ - return WeightQuantizeKernel(x, algo, arch, group_size); + const auto x_type = x.dtype(); +#define APPLY_WEIGHT_QUANTIZE_KERNEL(TX) \ + return WeightQuantizeKernel(x, algo, arch, group_size); - if (x_type == paddle::DataType::BFLOAT16) { - APPLY_WEIGHT_QUANTIZE_KERNEL(paddle::bfloat16); - } else if (x_type == paddle::DataType::FLOAT32) { - APPLY_WEIGHT_QUANTIZE_KERNEL(float); - } else { - PD_THROW("WeightQuantize not support x_type==%d", - static_cast(x_type)); - return {}; - } + if (x_type == paddle::DataType::BFLOAT16) { + APPLY_WEIGHT_QUANTIZE_KERNEL(paddle::bfloat16); + } else if (x_type == paddle::DataType::FLOAT32) { + APPLY_WEIGHT_QUANTIZE_KERNEL(float); + } else { + PD_THROW("WeightQuantize not support x_type==%d", static_cast(x_type)); + return {}; + } } -std::vector> -WeightQuantizeInferShape(const std::vector &x_shape, - const std::string &algo, const int32_t arch, - const int32_t group_size) { - if (algo == "weight_only_int8") { - return {x_shape, {x_shape[1]}}; - } else if (algo == "weight_only_int4") { - return {{x_shape[0] / 2, x_shape[1]}, {x_shape[1]}}; - } else { - PD_THROW("weight_quantize not support algo=%s", algo); - } +std::vector> WeightQuantizeInferShape( + const std::vector &x_shape, + const std::string &algo, + const int32_t arch, + const int32_t group_size) { + if (algo == "weight_only_int8") { + return {x_shape, {x_shape[1]}}; + } else if (algo == "weight_only_int4") { + return {{x_shape[0] / 2, x_shape[1]}, {x_shape[1]}}; + } else { + PD_THROW("weight_quantize not support algo=%s", algo); + } } -std::vector -WeightQuantizeInferDtype(const paddle::DataType &x_dtype, - const std::string &algo, const int32_t arch, - const int32_t group_size) { - if (algo == "weight_only_int8") { - return {paddle::DataType::INT8, paddle::DataType::FLOAT32}; - } else if (algo == "weight_only_int4") { - return {paddle::DataType::INT8, paddle::DataType::FLOAT32}; - } else { - PD_THROW("weight_quantize not support algo=%s", algo); - } +std::vector WeightQuantizeInferDtype( + const paddle::DataType &x_dtype, + const std::string &algo, + const int32_t arch, + const int32_t group_size) { + if (algo == "weight_only_int8") { + return {paddle::DataType::INT8, paddle::DataType::FLOAT32}; + } else if (algo == "weight_only_int4") { + return {paddle::DataType::INT8, paddle::DataType::FLOAT32}; + } else { + PD_THROW("weight_quantize not support algo=%s", algo); + } } PD_BUILD_OP(weight_quantize_xpu) diff --git a/custom_ops/xpu_ops/src/ops/xpu_multiprocess.h b/custom_ops/xpu_ops/src/ops/xpu_multiprocess.h index bde5fe4aa..430b559d8 100644 --- a/custom_ops/xpu_ops/src/ops/xpu_multiprocess.h +++ b/custom_ops/xpu_ops/src/ops/xpu_multiprocess.h @@ -24,60 +24,60 @@ #include #include #include -#include #include #include +#include struct shmStruct { - size_t nprocesses; + size_t nprocesses; #if XPURT_VERSION_MAJOR == 5 - XPUIpcMemHandle memHandle; + XPUIpcMemHandle memHandle; #endif - uint64_t data_ptr_addr; + uint64_t data_ptr_addr; }; struct sharedMemoryInfo { - void *addr; - size_t size; - int shmFd; + void *addr; + size_t size; + int shmFd; }; -static int sharedMemoryCreate(const char *name, size_t sz, +static int sharedMemoryCreate(const char *name, + size_t sz, sharedMemoryInfo *info) { - info->size = sz; + info->size = sz; - info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777); - PD_CHECK(info->shmFd >= 0, "shm_open failed"); + info->shmFd = shm_open(name, O_RDWR | O_CREAT, 0777); + PD_CHECK(info->shmFd >= 0, "shm_open failed"); - int status = ftruncate(info->shmFd, sz); - PD_CHECK(status == 0, "ftruncate failed"); + int status = ftruncate(info->shmFd, sz); + PD_CHECK(status == 0, "ftruncate failed"); - info->addr = - mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); - PD_CHECK(info->addr != NULL, "mmap failed"); + info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); + PD_CHECK(info->addr != NULL, "mmap failed"); - return 0; + return 0; } -static int sharedMemoryOpen(const char *name, size_t sz, +static int sharedMemoryOpen(const char *name, + size_t sz, sharedMemoryInfo *info) { - info->size = sz; + info->size = sz; - info->shmFd = shm_open(name, O_RDWR, 0777); - PD_CHECK(info->shmFd >= 0, "shm_open failed"); + info->shmFd = shm_open(name, O_RDWR, 0777); + PD_CHECK(info->shmFd >= 0, "shm_open failed"); - info->addr = - mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); - PD_CHECK(info->addr != nullptr, "mmap failed"); + info->addr = mmap(0, sz, PROT_READ | PROT_WRITE, MAP_SHARED, info->shmFd, 0); + PD_CHECK(info->addr != nullptr, "mmap failed"); - return 0; + return 0; } static void sharedMemoryClose(sharedMemoryInfo *info) { - if (info->addr) { - munmap(info->addr, info->size); - } - if (info->shmFd) { - close(info->shmFd); - } + if (info->addr) { + munmap(info->addr, info->size); + } + if (info->shmFd) { + close(info->shmFd); + } } diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 5ce255956..a399e5315 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -24,121 +24,176 @@ namespace api { namespace plugin { template -DLL_EXPORT int set_stop_value_multi_ends(Context *ctx, bool *stop_flags, - T *topk_ids, T *next_tokens, - const T *end_ids, const int *seq_lens, - const int bs, const int end_length, +DLL_EXPORT int set_stop_value_multi_ends(Context* ctx, + bool* stop_flags, + T* topk_ids, + T* next_tokens, + const T* end_ids, + const int* seq_lens, + const int bs, + const int end_length, const bool beam_search); -DLL_EXPORT int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags, - int64_t *pre_ids_all, - const int64_t *input_ids, - const int *seq_lens_encoder, - const int *seq_lens_decoder, - const int64_t *step_idx, int bs, - int length, int length_input_ids); +DLL_EXPORT int set_value_by_flags_and_idx(Context* ctx, + const bool* stop_flags, + int64_t* pre_ids_all, + const int64_t* input_ids, + const int* seq_lens_encoder, + const int* seq_lens_decoder, + const int64_t* step_idx, + int bs, + int length, + int length_input_ids); template -DLL_EXPORT int token_penalty_multi_scores( - Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores, - const T *frequency_scores, const T *presence_scores, - const float *temperatures, const int64_t *cur_len, const int64_t *min_len, - const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs, - const int64_t length, const int64_t length_id, const int64_t end_length, - const int64_t length_bad_words); +DLL_EXPORT int token_penalty_multi_scores(Context* ctx, + const int64_t* pre_ids, + T* logits, + const T* penalty_scores, + const T* frequency_scores, + const T* presence_scores, + const float* temperatures, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int64_t* bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words); -DLL_EXPORT int get_padding_offset(Context *ctx, int *padding_offset, - int *cum_offsets_out, int *cu_seqlens_q, - int *cu_seqlens_k, int64_t *x_remove_padding, - const int64_t *input_ids, - const int *cum_offsets, const int *seq_lens, - const int max_seq_len, const int bs); +DLL_EXPORT int get_padding_offset(Context* ctx, + int* padding_offset, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + int64_t* x_remove_padding, + const int64_t* input_ids, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + const int bs); -DLL_EXPORT int update_inputs(Context *ctx, bool *not_need_stop, - int *seq_lens_this_time, int *seq_lens_encoder, - int *seq_lens_decoder, int64_t *input_ids, - const int64_t *stop_nums, const bool *stop_flags, - const bool *is_block_step, - const int64_t *next_tokens, const int bsz, - const int max_bsz, const int input_ids_stride); +DLL_EXPORT int update_inputs(Context* ctx, + bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* input_ids, + const int64_t* stop_nums, + const bool* stop_flags, + const bool* is_block_step, + const int64_t* next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride); -DLL_EXPORT int free_and_dispatch_block( - Context *ctx, bool *stop_flags, int *seq_lens_this_time, - int *seq_lens_decoder, int *block_tables, int *encoder_block_lens, - bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz, - const int block_size, const int block_num_per_seq, - const int max_decoder_block_num); +DLL_EXPORT int free_and_dispatch_block(Context* ctx, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_decoder, + int* block_tables, + int* encoder_block_lens, + bool* is_block_step, + int* step_block_list, // [bsz] + int* step_len, + int* recover_block_list, + int* recover_len, + int* need_block_list, + int* need_block_len, + int* used_list_len, + int* free_list, + int* free_list_len, + int64_t* first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num); -DLL_EXPORT int -recover_block(Context *ctx, - int *recover_block_list, // [bsz] - int *recover_len, bool *stop_flags, int *seq_lens_this_time, - const int *ori_seq_lens_encoder, int *seq_lens_encoder, - const int *seq_lens_decoder, int *block_tables, int *free_list, - int *free_list_len, int64_t *input_ids, const int64_t *pre_ids, - const int64_t *step_idx, const int *encoder_block_lens, - const int *used_list_len, const int64_t *next_tokens, - const int64_t *first_token_ids, const int bsz, - const int block_num_per_seq, const int length, - const int pre_id_length); +DLL_EXPORT int recover_block(Context* ctx, + int* recover_block_list, // [bsz] + int* recover_len, + bool* stop_flags, + int* seq_lens_this_time, + const int* ori_seq_lens_encoder, + int* seq_lens_encoder, + const int* seq_lens_decoder, + int* block_tables, + int* free_list, + int* free_list_len, + int64_t* input_ids, + const int64_t* pre_ids, + const int64_t* step_idx, + const int* encoder_block_lens, + const int* used_list_len, + const int64_t* next_tokens, + const int64_t* first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); - -DLL_EXPORT int -recover_decode_task(Context *ctx, bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, +DLL_EXPORT int recover_decode_task(Context* ctx, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int* block_tables, + bool* is_block_step, const int bsz, const int block_num_per_seq, const int block_size); -DLL_EXPORT int -update_inputs_v1(Context *ctx, bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, - const int bsz, - const int max_bsz, - const int input_ids_stride, - const int block_num_per_seq, - const int block_size); +DLL_EXPORT int update_inputs_v1(Context* ctx, + bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int64_t* prompt_lens, + int64_t* topk_ids, + int64_t* input_ids, + int* block_tables, + const int64_t* stop_nums, + bool* stop_flags, + bool* is_block_step, + const int64_t* next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size); template -DLL_EXPORT int -eb_adjust_batch(Context *ctx, const TX *x, TY *y, - VectorParam &encoder_seqs_lods, // NOLINT - VectorParam &encoder_batch_map, // NOLINT - VectorParam &decoder_batch_map, // NOLINT - int64_t hidden_dim); +DLL_EXPORT int eb_adjust_batch( + Context* ctx, + const TX* x, + TY* y, + VectorParam& encoder_seqs_lods, // NOLINT + VectorParam& encoder_batch_map, // NOLINT + VectorParam& decoder_batch_map, // NOLINT + int64_t hidden_dim); template -DLL_EXPORT int -eb_gather_next_token(Context *ctx, const TX *x, TY *y, - VectorParam &encoder_seqs_lods, // NOLINT - VectorParam &encoder_batch_map, // NOLINT - VectorParam &decoder_batch_map, // NOLINT - int64_t hidden_dim); +DLL_EXPORT int eb_gather_next_token( + Context* ctx, + const TX* x, + TY* y, + VectorParam& encoder_seqs_lods, // NOLINT + VectorParam& encoder_batch_map, // NOLINT + VectorParam& decoder_batch_map, // NOLINT + int64_t hidden_dim); template -DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x, - const TSCALE *scale_in, TY *y, - TSCALE *scale_out, int64_t m, int64_t n); +DLL_EXPORT int quant2d_per_channel(api::Context* ctx, + const TX* x, + const TSCALE* scale_in, + TY* y, + TSCALE* scale_out, + int64_t m, + int64_t n); DLL_EXPORT int text_image_index_out(Context* ctx, const int* token_type_ids, // x @@ -160,7 +215,8 @@ DLL_EXPORT int text_image_gather_scatter(api::Context* ctx, int64_t hidden_size, bool is_scatter); -/*--------------------------------------- MTP being --------------------------------------------*/ +/*--------------------------------------- MTP being + * --------------------------------------------*/ template DLL_EXPORT int speculate_token_penalty_multi_scores( @@ -200,7 +256,6 @@ DLL_EXPORT int mtp_free_and_dispatch_block(Context* ctx, const int block_num_per_seq, const int max_draft_tokens); - template DLL_EXPORT int speculate_verify(Context* ctx, int64_t* accept_tokens, @@ -457,9 +512,10 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx, T* output, int dim_embed, int elem_cnt); -/*--------------------------------------- MTP end --------------------------------------------*/ +/*--------------------------------------- MTP end + * --------------------------------------------*/ -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu index 5416b0045..d53a87b3e 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu @@ -36,8 +36,8 @@ __global__ void get_padding_offset(int *batch_id_per_token, } mfence_lm(); LM2GM(batch_id_per_token_lm, - batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j, - cur_len * sizeof(int)); + batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j, + cur_len * sizeof(int)); } if (cid == 0) { int cum_seq_len = (i + 1) * max_seq_len - cum_offsets_lm[1]; diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu index a098a0e6f..d63a249bf 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/rebuild_append_padding.xpu @@ -15,72 +15,72 @@ __global__ void RebuildAppendPaddingKernel(const T *full_hidden_states, int dim_embed, int elem_nums, T *out) { - int ncores = core_num(); - int cid = core_id(); - int tid = cid * cluster_num() + cluster_id(); - int nthreads = cluster_num() * ncores; - int64_t mstart = -1; - int64_t mend = -1; - int64_t nstart = -1; - int64_t nend = -1; - partition2d(tid, - nthreads, - elem_nums / dim_embed, - dim_embed, - &mstart, - &mend, - &nstart, - &nend); + int ncores = core_num(); + int cid = core_id(); + int tid = cid * cluster_num() + cluster_id(); + int nthreads = cluster_num() * ncores; + int64_t mstart = -1; + int64_t mend = -1; + int64_t nstart = -1; + int64_t nend = -1; + partition2d(tid, + nthreads, + elem_nums / dim_embed, + dim_embed, + &mstart, + &mend, + &nstart, + &nend); - const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64); - __simd__ T lm_full_hidden_states[BUFFER_LEN]; - int output_padding_offset_val, cum_offset_val, seq_len_encoder_val, - seq_len_decoder_val; + const int64_t BUFFER_LEN = rounddown(6144 / sizeof(T), 64); + __simd__ T lm_full_hidden_states[BUFFER_LEN]; + int output_padding_offset_val, cum_offset_val, seq_len_encoder_val, + seq_len_decoder_val; - for (int64_t _m = mstart; _m < mend; _m++) { - int out_token_id = _m; - GM2LM(output_padding_offset + out_token_id, - &output_padding_offset_val, - sizeof(int)); - int ori_token_id = out_token_id + output_padding_offset_val; - int bi = ori_token_id / max_seq_len; - GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int)); - GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int)); - int seq_id = 0; - if (seq_len_encoder_val == 0 and seq_len_decoder_val == 0) { - continue; - } else if (seq_len_encoder_val != 0) { - seq_id = seq_len_encoder_val - 1; - } - GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int)); - int input_token_id = ori_token_id - cum_offset_val + seq_id; - for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) { - int64_t read_size = min(BUFFER_LEN, nend - _n); - // out[i] = full_hidden_states[(i / dim_embed + - // output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed - // + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id) - // * dim_embed + i % dim_embed] - GM2LM(full_hidden_states + input_token_id * dim_embed + _n, - lm_full_hidden_states, - read_size * sizeof(T)); - LM2GM(lm_full_hidden_states, - out + _m * dim_embed + _n, - read_size * sizeof(T)); - } + for (int64_t _m = mstart; _m < mend; _m++) { + int out_token_id = _m; + GM2LM(output_padding_offset + out_token_id, + &output_padding_offset_val, + sizeof(int)); + int ori_token_id = out_token_id + output_padding_offset_val; + int bi = ori_token_id / max_seq_len; + GM2LM_ASYNC(seq_len_encoder + bi, &seq_len_encoder_val, sizeof(int)); + GM2LM(seq_len_decoder + bi, &seq_len_decoder_val, sizeof(int)); + int seq_id = 0; + if (seq_len_encoder_val == 0 and seq_len_decoder_val == 0) { + continue; + } else if (seq_len_encoder_val != 0) { + seq_id = seq_len_encoder_val - 1; } + GM2LM(cum_offset + bi, &cum_offset_val, sizeof(int)); + int input_token_id = ori_token_id - cum_offset_val + seq_id; + for (int64_t _n = nstart; _n < nend; _n += BUFFER_LEN) { + int64_t read_size = min(BUFFER_LEN, nend - _n); + // out[i] = full_hidden_states[(i / dim_embed + + // output_padding_offset[i / dim_embed] - cum_offset[(i / dim_embed + // + output_padding_offset[i / dim_embed]) / max_seq_len] + seq_id) + // * dim_embed + i % dim_embed] + GM2LM(full_hidden_states + input_token_id * dim_embed + _n, + lm_full_hidden_states, + read_size * sizeof(T)); + LM2GM(lm_full_hidden_states, + out + _m * dim_embed + _n, + read_size * sizeof(T)); + } + } } -#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \ - template __global__ void RebuildAppendPaddingKernel( \ - const T *full_hidden_states, \ - const int *cum_offset, \ - const int *seq_len_encoder, \ - const int *seq_len_decoder, \ - const int *output_padding_offset, \ - int max_seq_len, \ - int dim_embed, \ - int elem_nums, \ - T *out); +#define _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(T) \ + template __global__ void RebuildAppendPaddingKernel( \ + const T *full_hidden_states, \ + const int *cum_offset, \ + const int *seq_len_encoder, \ + const int *seq_len_decoder, \ + const int *output_padding_offset, \ + int max_seq_len, \ + int dim_embed, \ + int elem_nums, \ + T *out); _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(bfloat16); _XPU_DEF_REBUILD_APPEND_PADDING_KERNEL(float16); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu index 4f42fd69f..de24d1835 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_repeat_times.xpu @@ -152,8 +152,8 @@ __device__ void speculate_update_repeat_times_optimized( repeat_times_read_size_per_core * sizeof(int)); } sync_all(); - // each core loads pre_ids step by step and record the index of pre_ids - // which is less than zero, and store the index to boundary + // each core loads pre_ids step by step and record the index of + // pre_ids which is less than zero, and store the index to boundary if (repeat_times_start == 0) { bool do_prone = false; int64_t j = cid * pre_ids_lm_len; @@ -190,8 +190,8 @@ __device__ void speculate_update_repeat_times_optimized( buffer_ptr_pre_ids.toggle(); } } - // each core loads all the needed pre_ids into lm without mfence in between - // according to the index recorded by previous iteration + // each core loads all the needed pre_ids into lm without mfence in + // between according to the index recorded by previous iteration else { int cnt = -1; int64_t pre_ids_read_size = 0; diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu index f685fcf9e..51f2964e9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_update_value_by_repeat_times.xpu @@ -240,18 +240,20 @@ __global__ void speculate_update_value_by_repeat_times_simd( alpha, logits_, logits_, - (time_mask & - ~logit_mask)); // when time != 0 && logit < 0, do alpha * logit + (time_mask & ~logit_mask)); // when time != 0 && logit < 0, do + // alpha * logit logits_ = svmul_float32x16_mh( 1.0f / alpha, logits_, logits_, (time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha - logits_ = vvsub_float32x16_mh( - logits_, time_, logits_, time_mask); // when time != 0, do logit = - // logit - time * beta - gamma; - logits_ = - svmul_float32x16(1.0f / temperature, logits_); // logit / temperature + logits_ = vvsub_float32x16_mh(logits_, + time_, + logits_, + time_mask); // when time != 0, do logit = + // logit - time * beta - gamma; + logits_ = svmul_float32x16(1.0f / temperature, + logits_); // logit / temperature vstore_lm_float32x16(logits_lm + j, logits_); } mfence_lm(); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu index db6efb4c7..b4be6ad5a 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/recover_decode_task.xpu @@ -6,15 +6,15 @@ namespace xpu3 { namespace plugin { __global__ void recover_decode_task(bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, - const int bsz, - const int block_num_per_seq, - const int block_size) { + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { int cid = core_id(); int ncores = core_num(); int clusterid = cluster_id(); @@ -23,15 +23,17 @@ __global__ void recover_decode_task(bool *stop_flags, int nthreads = nclusters * ncores; // if (clusterid != 0) return; for (; thread_idx < bsz; thread_idx += nthreads) { - if(is_block_step[thread_idx] == true) { - // int *block_table_now = block_tables + thread_idx * block_num_per_seq; - if (block_tables[thread_idx * block_num_per_seq + step_seq_lens_decoder[thread_idx] / block_size] != -1) { - // can be recovered for decoding - is_block_step[thread_idx] = false; - seq_lens_this_time[thread_idx]= 1; - stop_flags[thread_idx] = false; - seq_lens_encoder[thread_idx] = 0; - seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; + if (is_block_step[thread_idx] == true) { + // int *block_table_now = block_tables + thread_idx * + // block_num_per_seq; + if (block_tables[thread_idx * block_num_per_seq + + step_seq_lens_decoder[thread_idx] / block_size] != -1) { + // can be recovered for decoding + is_block_step[thread_idx] = false; + seq_lens_this_time[thread_idx] = 1; + stop_flags[thread_idx] = false; + seq_lens_encoder[thread_idx] = 0; + seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; } } } diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/remove_padding.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/remove_padding.xpu index e2f4048bd..a28dfa934 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/remove_padding.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/remove_padding.xpu @@ -30,8 +30,8 @@ __global__ void remove_padding(int64_t *x_remove_padding, input_lm, sizeof(int64_t) * cur_len); LM2GM(input_lm, - x_remove_padding + i * sequence_length - cum_offset_lm + j, - sizeof(int64_t) * cur_len); + x_remove_padding + i * sequence_length - cum_offset_lm + j, + sizeof(int64_t) * cur_len); } } } diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/set_stop_value_multi_ends.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/set_stop_value_multi_ends.xpu index f3205c100..178a3c473 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/set_stop_value_multi_ends.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/set_stop_value_multi_ends.xpu @@ -54,14 +54,13 @@ __global__ void set_stop_value_multi_ends(bool* stop_flags, GM2LM_ASYNC(seq_lens + i, seq_lens_lm, sizeof(int) * readlen); mfence(); for (int j = 0; j < readlen; j++) { - if(prefill_one_step_stop){ + if (prefill_one_step_stop) { stop_flags_lm[j] = true; if (seq_lens_lm[j] == 0) { topk_ids_lm[j] = -1; } next_tokens_lm[j] = topk_ids_lm[j]; - } - else{ + } else { if (stop_flags_lm[j]) { if (seq_lens_lm[j] == 0) { topk_ids_lm[j] = -1; diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu index 608cda1c6..777af5491 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu @@ -8,166 +8,206 @@ namespace plugin { template static __device__ inline void text_image_gather( - __global_ptr__ T* input, - __global_ptr__ T* text_input, - __global_ptr__ T* image_input, - __global_ptr__ int* token_type_ids, - __global_ptr__ int* text_index, - __global_ptr__ int* image_index, - int64_t token_num, - int64_t text_token_num, - int64_t image_token_num, - int64_t hidden_size, - T* input_lm) { - int cid = core_id(); - int clusterid = cluster_id(); - int token_start_cluster; - int token_end_cluster; - int token_start_core; - int token_end_core; + __global_ptr__ T* input, + __global_ptr__ T* text_input, + __global_ptr__ T* image_input, + __global_ptr__ int* token_type_ids, + __global_ptr__ int* text_index, + __global_ptr__ int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + T* input_lm) { + int cid = core_id(); + int clusterid = cluster_id(); + int token_start_cluster; + int token_end_cluster; + int token_start_core; + int token_end_core; - const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 - // cluster partition - partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster); - if (token_start_cluster >= token_end_cluster) { - return; + const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 + // cluster partition + partition(cluster_id(), + cluster_num(), + (int)token_num, + 1, + &token_start_cluster, + &token_end_cluster); + if (token_start_cluster >= token_end_cluster) { + return; + } + int rows_cluster = + token_end_cluster - token_start_cluster; // total rows for a cluster + // core partition + partition(core_id(), + core_num(), + rows_cluster, + 1, + &token_start_core, + &token_end_core); + int rows_core = token_end_core - token_start_core; // total rows for a core + token_start_core += token_start_cluster; + token_end_core += token_start_cluster; + + int read_len; + for (int i = token_start_core; i < token_end_core; i += 1) { + int token_type, text_image_token_idx; + __global_ptr__ T* text_image_input = nullptr; + __global_ptr__ int* text_image_index = nullptr; + + GM2LM(token_type_ids + i, &token_type, sizeof(int)); + if (token_type == 0) { + text_image_input = text_input; + text_image_index = text_index; + } else { + text_image_input = image_input; + text_image_index = image_index; } - int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster - // core partition - partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core); - int rows_core = token_end_core - token_start_core; // total rows for a core - token_start_core += token_start_cluster; - token_end_core += token_start_cluster; + GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int)); + int input_offset = i * hidden_size; + int text_image_offset = text_image_token_idx * hidden_size; - int read_len; - for (int i = token_start_core; i < token_end_core; i += 1) { - int token_type, text_image_token_idx; - __global_ptr__ T* text_image_input = nullptr; - __global_ptr__ int* text_image_index = nullptr; - - GM2LM(token_type_ids + i, &token_type, sizeof(int)); - if (token_type == 0) { - text_image_input = text_input; - text_image_index = text_index; - } else { - text_image_input = image_input; - text_image_index = image_index; - } - GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int)); - int input_offset = i * hidden_size; - int text_image_offset = text_image_token_idx * hidden_size; - - for (int j = 0; j < hidden_size; j += BUFSIZE) { - read_len = min(hidden_size - j, BUFSIZE); - GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len); - LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len); - } + for (int j = 0; j < hidden_size; j += BUFSIZE) { + read_len = min(hidden_size - j, BUFSIZE); + GM2LM(text_image_input + text_image_offset + j, + input_lm, + sizeof(T) * read_len); + LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len); } + } } template static __device__ inline void text_image_scatter( - __global_ptr__ T* input, - __global_ptr__ T* text_input, - __global_ptr__ T* image_input, - __global_ptr__ int* token_type_ids, - __global_ptr__ int* text_index, - __global_ptr__ int* image_index, - int64_t token_num, - int64_t text_token_num, - int64_t image_token_num, - int64_t hidden_size, - T* input_lm) { - int cid = core_id(); - int clusterid = cluster_id(); - int token_start_cluster; - int token_end_cluster; - int token_start_core; - int token_end_core; + __global_ptr__ T* input, + __global_ptr__ T* text_input, + __global_ptr__ T* image_input, + __global_ptr__ int* token_type_ids, + __global_ptr__ int* text_index, + __global_ptr__ int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + T* input_lm) { + int cid = core_id(); + int clusterid = cluster_id(); + int token_start_cluster; + int token_end_cluster; + int token_start_core; + int token_end_core; - const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 - // cluster partition - partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster); - if (token_start_cluster >= token_end_cluster) { - return; + const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 + // cluster partition + partition(cluster_id(), + cluster_num(), + (int)token_num, + 1, + &token_start_cluster, + &token_end_cluster); + if (token_start_cluster >= token_end_cluster) { + return; + } + int rows_cluster = + token_end_cluster - token_start_cluster; // total rows for a cluster + // core partition + partition(core_id(), + core_num(), + rows_cluster, + 1, + &token_start_core, + &token_end_core); + int rows_core = token_end_core - token_start_core; // total rows for a core + token_start_core += token_start_cluster; + token_end_core += token_start_cluster; + + int read_len; + for (int i = token_start_core; i < token_end_core; i += 1) { + int token_type, text_image_token_idx; + __global_ptr__ T* text_image_input = nullptr; + __global_ptr__ int* text_image_index = nullptr; + + GM2LM(token_type_ids + i, &token_type, sizeof(int)); + if (token_type == 0) { + text_image_input = text_input; + text_image_index = text_index; + } else { + text_image_input = image_input; + text_image_index = image_index; } - int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster - // core partition - partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core); - int rows_core = token_end_core - token_start_core; // total rows for a core - token_start_core += token_start_cluster; - token_end_core += token_start_cluster; + GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int)); + int input_offset = i * hidden_size; + int text_image_offset = text_image_token_idx * hidden_size; - int read_len; - for (int i = token_start_core; i < token_end_core; i += 1) { - int token_type, text_image_token_idx; - __global_ptr__ T* text_image_input = nullptr; - __global_ptr__ int* text_image_index = nullptr; - - GM2LM(token_type_ids + i, &token_type, sizeof(int)); - if (token_type == 0) { - text_image_input = text_input; - text_image_index = text_index; - } else { - text_image_input = image_input; - text_image_index = image_index; - } - GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int)); - int input_offset = i * hidden_size; - int text_image_offset = text_image_token_idx * hidden_size; - - for (int j = 0; j < hidden_size; j += BUFSIZE) { - read_len = min(hidden_size - j, BUFSIZE); - GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len); - LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len); - } + for (int j = 0; j < hidden_size; j += BUFSIZE) { + read_len = min(hidden_size - j, BUFSIZE); + GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len); + LM2GM(input_lm, + text_image_input + text_image_offset + j, + sizeof(T) * read_len); } + } } template -__global__ void text_image_gather_scatter( - T* input, - T* text_input, - T* image_input, - int* token_type_ids, - int* text_index, - int* image_index, - int64_t token_num, - int64_t text_token_num, - int64_t image_token_num, - int64_t hidden_size, - bool is_scatter) { - int cid = core_id(); - int ncores = core_num(); - int clusterid = cluster_id(); - int nclusters = cluster_num(); - const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 - __simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32 - if (is_scatter) { - text_image_scatter( - input, text_input, image_input, token_type_ids, text_index, image_index, - token_num, text_token_num, image_token_num, hidden_size, input_lm); - } else { - text_image_gather( - input, text_input, image_input, token_type_ids, text_index, image_index, - token_num, text_token_num, image_token_num, hidden_size, input_lm); - } +__global__ void text_image_gather_scatter(T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 + __simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32 + if (is_scatter) { + text_image_scatter(input, + text_input, + image_input, + token_type_ids, + text_index, + image_index, + token_num, + text_token_num, + image_token_num, + hidden_size, + input_lm); + } else { + text_image_gather(input, + text_input, + image_input, + token_type_ids, + text_index, + image_index, + token_num, + text_token_num, + image_token_num, + hidden_size, + input_lm); + } } - -#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \ - template __global__ void text_image_gather_scatter( \ - T* input, \ - T* text_input, \ - T* image_input, \ - int* token_type_ids, \ - int* text_index, \ - int* image_index, \ - int64_t token_num, \ - int64_t text_token_num, \ - int64_t image_token_num, \ - int64_t hidden_size, \ - bool is_scatter); +#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \ + template __global__ void text_image_gather_scatter( \ + T * input, \ + T * text_input, \ + T * image_input, \ + int* token_type_ids, \ + int* text_index, \ + int* image_index, \ + int64_t token_num, \ + int64_t text_token_num, \ + int64_t image_token_num, \ + int64_t hidden_size, \ + bool is_scatter); _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu index f8c972ef3..96112742d 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu @@ -23,75 +23,92 @@ namespace xpu3 { namespace plugin { -static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) { - for (int j = 0; j < size; j++) { - if (lm_x[j] == 0) { - lm_y1[j] = text_count; - text_count += 1; - } else { - lm_y2[j] = images_count; - images_count += 1; - } +static __device__ void do_calc(const _shared_ptr_ int* lm_x, + int* lm_y1, + int* lm_y2, + int64_t size, + int& text_count, + int& images_count) { + for (int j = 0; j < size; j++) { + if (lm_x[j] == 0) { + lm_y1[j] = text_count; + text_count += 1; + } else { + lm_y2[j] = images_count; + images_count += 1; } - mfence_lm_sm(); + } + mfence_lm_sm(); } -__global__ void text_image_index_out_kernel( - const int* token_type_ids, // x - int* text_index, // y1 - int* image_index, // y2 - const int64_t token_num) { - const int cid = core_id(); - const int tid = core_id() * cluster_num() + cluster_id(); - const int nthreads = core_num() * cluster_num(); - if (tid >= 1) return; - constexpr int BUFSIZE = 1024; - constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int); - const int64_t len = token_num; +__global__ void text_image_index_out_kernel(const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num) { + const int cid = core_id(); + const int tid = core_id() * cluster_num() + cluster_id(); + const int nthreads = core_num() * cluster_num(); + if (tid >= 1) return; + constexpr int BUFSIZE = 1024; + constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int); + const int64_t len = token_num; - __simd__ char buffer0[BUFSIZE * 3]; - __simd__ char buffer1[BUFSIZE * 3]; - __simd__ __shared__ char buffer2[64][BUFSIZE * 2]; + __simd__ char buffer0[BUFSIZE * 3]; + __simd__ char buffer1[BUFSIZE * 3]; + __simd__ __shared__ char buffer2[64][BUFSIZE * 2]; - DoublePtr> buffer_ptr_x((SmPtr((_shared_ptr_ int*)buffer2[cid]))); - TriplePtr> buffer_ptr_y1((LmPtr((int*)buffer0))); - TriplePtr> buffer_ptr_y2((LmPtr((int*)buffer1))); - int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64); - int64_t i = tid * buflen; - int read_size = 0; - int offset = nthreads * buflen; + DoublePtr> buffer_ptr_x( + (SmPtr((_shared_ptr_ int*)buffer2[cid]))); + TriplePtr> buffer_ptr_y1( + (LmPtr((int*)buffer0))); + TriplePtr> buffer_ptr_y2( + (LmPtr((int*)buffer1))); + int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64); + int64_t i = tid * buflen; + int read_size = 0; + int offset = nthreads * buflen; - int text_count = 0; - int images_count = 0; + int text_count = 0; + int images_count = 0; - if (i < len) { - read_size = min(buflen, len - i); - buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size); - buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size); - buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size); - mfence(); - } - while (i < len && i + offset < len) { - i = i + offset; - int read_size_next = min(buflen, len - i); - buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next); - buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next); - buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next); + if (i < len) { + read_size = min(buflen, len - i); + buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size); + buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size); + buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size); + mfence(); + } + while (i < len && i + offset < len) { + i = i + offset; + int read_size_next = min(buflen, len - i); + buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next); + buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next); + buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next); - do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count); + do_calc(buffer_ptr_x.ptr, + buffer_ptr_y1.ptr, + buffer_ptr_y2.ptr, + read_size, + text_count, + images_count); - buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size); - buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size); - buffer_ptr_x.toggle(); - buffer_ptr_y1.toggle(); - buffer_ptr_y2.toggle(); - read_size = read_size_next; - } - if (i < len) { - do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count); - buffer_ptr_y1.gm_store_async(text_index + i, read_size); - buffer_ptr_y2.gm_store(image_index + i, read_size); - } + buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size); + buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size); + buffer_ptr_x.toggle(); + buffer_ptr_y1.toggle(); + buffer_ptr_y2.toggle(); + read_size = read_size_next; + } + if (i < len) { + do_calc(buffer_ptr_x.ptr, + buffer_ptr_y1.ptr, + buffer_ptr_y2.ptr, + read_size, + text_count, + images_count); + buffer_ptr_y1.gm_store_async(text_index + i, read_size); + buffer_ptr_y2.gm_store(image_index + i, read_size); + } } } // namespace plugin } // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs.xpu index e1bb6b57c..0da8743fe 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs.xpu @@ -46,7 +46,8 @@ __global__ void update_inputs(bool *not_need_stop, int seq_len_decoder_update = stop_flag_now ? 0 - : (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder) : seq_len_decoder + 1); + : (seq_len_encoder > 0 ? (seq_len_encoder + seq_len_decoder) + : seq_len_decoder + 1); int seq_len_this_time_update = !stop_flag_now; int seq_len_encoder_update = 0; mfence_lm(); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu index 8eb87c12d..e38b47bf3 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu @@ -4,32 +4,30 @@ // #include // using namespace std; -#include "xpu/kernel/xtdk_io.h" #include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_io.h" namespace xpu3 { namespace plugin { __global__ void update_inputs_v1(bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, - const int bsz, - const int max_bsz, - const int input_ids_stride, - const int block_num_per_seq, - const int block_size) { - - + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size) { // std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl; int cid = core_id(); int ncores = core_num(); @@ -41,74 +39,83 @@ __global__ void update_inputs_v1(bool *not_need_stop, const int max_bs = 1024; __shared__ bool stop_flags_sm[max_bs]; __shared__ int stop_flags_int_sm[max_bs]; - if(cid == 0){ + if (cid == 0) { GM2SM(stop_flags, stop_flags_sm, sizeof(bool) * bsz); } sync_all(); - for(int i = cid; i < bsz; i+= ncores){ - if(i < bsz){ - stop_flags_sm[i] = stop_flags[i]; - stop_flags_int_sm[i] = static_cast(stop_flags_sm[i]); - }else{ - stop_flags_sm[i] = true; - stop_flags_int_sm[i] = 1; + for (int i = cid; i < bsz; i += ncores) { + if (i < bsz) { + stop_flags_sm[i] = stop_flags[i]; + stop_flags_int_sm[i] = static_cast(stop_flags_sm[i]); + } else { + stop_flags_sm[i] = true; + stop_flags_int_sm[i] = 1; } - if(i= + prompt_lens_update) { + seq_len_decoder_update = + seq_len_this_time_update + seq_len_decoder_update; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + seq_len_this_time_update = 1; + LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + seq_lens_encoder_update = 0; + LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); + int64_t input_ids_update; + GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t)); + LM2GM(&input_ids_update, + input_ids + i * input_ids_stride, + sizeof(int64_t)); + // to judge whether block is not enough + if (seq_len_this_time_update != 0 && + block_tables[i * block_num_per_seq + + seq_len_decoder_update / block_size] == -1) { + is_block_step[i] = true; + seq_len_this_time_update = 0; + LM2GM( + &seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + stop_flags_sm[i] = true; + SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool)); + LM2GM(&seq_len_decoder_update, + step_seq_lens_decoder + i, + sizeof(int)); + seq_len_decoder_update = 0; LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); - LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); - }else{ - GM2LM(seq_lens_this_time+i, &seq_len_this_time_update, sizeof(int)); - GM2LM(seq_lens_decoder+i, &seq_len_decoder_update, sizeof(int)); - GM2LM(seq_lens_encoder+i, &seq_lens_encoder_update, sizeof(int)); - int sum_of_seq_lens_this_time_and_seq_lens_decoder = seq_len_this_time_update + seq_len_decoder_update; - int prompt_lens_update = 0; - GM2LM(prompt_lens+i, &prompt_lens_update, sizeof(int64_t)); - // decoding - if(sum_of_seq_lens_this_time_and_seq_lens_decoder >= prompt_lens_update){ - seq_len_decoder_update = seq_len_this_time_update + seq_len_decoder_update; - LM2GM(&seq_len_decoder_update, seq_lens_decoder+i, sizeof(int)); - seq_len_this_time_update = 1; - LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); - seq_lens_encoder_update = 0; - LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); - int64_t input_ids_update; - GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t)); - LM2GM(&input_ids_update, input_ids + i * input_ids_stride, sizeof(int64_t)); - // to judge whether block is not enough - if(seq_len_this_time_update != 0 && block_tables[i * block_num_per_seq + seq_len_decoder_update/block_size] == -1){ - is_block_step[i] = true; - seq_len_this_time_update = 0; - LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); - stop_flags_sm[i] = true; - SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool)); - LM2GM(&seq_len_decoder_update, step_seq_lens_decoder+i, sizeof(int)); - seq_len_decoder_update = 0; - LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); - seq_len_decoder_update = 0; - LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); - stop_flags_int_sm[i] = 1; - } - }else{ - stop_flags_sm[i] = true; - SM2GM(stop_flags_sm+i, stop_flags+i, sizeof(bool)); - seq_len_this_time_update = 0; - LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); - seq_len_decoder_update = 0; - seq_lens_encoder_update = 0; - LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); - LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); - int64_t topk_ids_update = -1; - LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t)); - stop_flags_int_sm[i] = 1; - } - + seq_len_decoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + stop_flags_int_sm[i] = 1; + } + } else { + stop_flags_sm[i] = true; + SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool)); + seq_len_this_time_update = 0; + LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + seq_len_decoder_update = 0; + seq_lens_encoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); + int64_t topk_ids_update = -1; + LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t)); + stop_flags_int_sm[i] = 1; } + } } } sync_all(); diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_value_by_repeat_times.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_value_by_repeat_times.xpu index b1cb1dbf0..f85111768 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_value_by_repeat_times.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_value_by_repeat_times.xpu @@ -6,16 +6,16 @@ namespace xpu3 { namespace plugin { -__device__ void do_cast(const int* xlm, float* ylm, int64_t len) { - for (int64_t i = 0; i < len; i += 32) { - int32x16_t xl = vload_lm_int32x16(xlm + i); - int32x16_t xh = vload_lm_int32x16(xlm + i + 16); - float32x16_t yl = vfix2float(xl); - float32x16_t yh = vfix2float(xh); - vstore_lm_float32x16(ylm + i, yl); - vstore_lm_float32x16(ylm + i + 16, yh); - } - mfence_lm(); +__device__ void do_cast(const int *xlm, float *ylm, int64_t len) { + for (int64_t i = 0; i < len; i += 32) { + int32x16_t xl = vload_lm_int32x16(xlm + i); + int32x16_t xh = vload_lm_int32x16(xlm + i + 16); + float32x16_t yl = vfix2float(xl); + float32x16_t yh = vfix2float(xh); + vstore_lm_float32x16(ylm + i, yl); + vstore_lm_float32x16(ylm + i + 16, yh); + } + mfence_lm(); } template @@ -124,7 +124,8 @@ __global__ void update_value_by_repeat_times_simd( int nthreads = cluster_num() * ncores; int start = -1; int end = -1; - partition(thread_id, nthreads, static_cast(bs * length), 16, &start, &end); + partition( + thread_id, nthreads, static_cast(bs * length), 16, &start, &end); const int param_len = 256; // ncores = 64 for xpu3 @@ -178,14 +179,28 @@ __global__ void update_value_by_repeat_times_simd( alpha = alpha_buf[param_idx]; beta = beta_buf[param_idx]; gamma = gamma_buf[param_idx]; - time_mask = svneq_float32x16(0.f, time_); // time != 0 mask - logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask - time_ = svmul_float32x16(beta, time_); // time * beta - time_ = svadd_float32x16(gamma, time_); // time * beta + gamma - logits_ = svmul_float32x16_mh(alpha, logits_, logits_, (time_mask & ~logit_mask)); // when time != 0 && logit < 0, do alpha * logit - logits_ = svmul_float32x16_mh(1.0f / alpha, logits_, logits_, (time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha - logits_ = vvsub_float32x16_mh(logits_, time_, logits_, time_mask); // when time != 0, do logit = logit - time * beta - gamma; - logits_ = svmul_float32x16(1.0f / temperature, logits_); // logit / temperature + time_mask = svneq_float32x16(0.f, time_); // time != 0 mask + logit_mask = svle_float32x16(0.f, logits_); // logit >= 0 mask + time_ = svmul_float32x16(beta, time_); // time * beta + time_ = svadd_float32x16(gamma, time_); // time * beta + gamma + logits_ = svmul_float32x16_mh( + alpha, + logits_, + logits_, + (time_mask & ~logit_mask)); // when time != 0 && logit < 0, do + // alpha * logit + logits_ = svmul_float32x16_mh( + 1.0f / alpha, + logits_, + logits_, + (time_mask & logit_mask)); // when time != 0 && >=0, do logit / alpha + logits_ = vvsub_float32x16_mh(logits_, + time_, + logits_, + time_mask); // when time != 0, do logit = + // logit - time * beta - gamma; + logits_ = svmul_float32x16(1.0f / temperature, + logits_); // logit / temperature vstore_lm_float32x16(logits_lm + j, logits_); } mfence_lm(); @@ -195,14 +210,14 @@ __global__ void update_value_by_repeat_times_simd( } #define _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(DATA_TYPE) \ - template __global__ void update_value_by_repeat_times_simd( \ - const int *repeat_times, \ - const DATA_TYPE *penalty_scores, \ - const DATA_TYPE *frequency_score, \ - const DATA_TYPE *presence_score, \ - const float *temperatures, \ - DATA_TYPE *logits, \ - const int64_t bs, \ + template __global__ void update_value_by_repeat_times_simd( \ + const int *repeat_times, \ + const DATA_TYPE *penalty_scores, \ + const DATA_TYPE *frequency_score, \ + const DATA_TYPE *presence_score, \ + const float *temperatures, \ + DATA_TYPE *logits, \ + const int64_t bs, \ const int64_t length); _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float); _XPU_DEF__UPDATE_VALUE_BY_REPEAT_TIMES_SIMD(float16); diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp index 121e06192..94f235213 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp @@ -20,12 +20,16 @@ namespace xpu3 { namespace plugin { template -__attribute__((global)) void -eb_adjust_batch(TX *src, TY *dst, int *encoder_seqs_lods, - int *encoder_batch_map, int *decoder_batch_map, int en_batch, - int de_batch, int64_t copy_size); -} // namespace plugin -} // namespace xpu3 +__attribute__((global)) void eb_adjust_batch(TX *src, + TY *dst, + int *encoder_seqs_lods, + int *encoder_batch_map, + int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size); +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { @@ -33,10 +37,15 @@ namespace api { namespace plugin { template -static int -cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, - const int *encoder_batch_map, const int *decoder_batch_map, - int en_batch, int de_batch, int64_t hidden_dim) { +static int cpu_wrapper(api::Context *ctx, + const TX *x, + TY *y, + const int *encoder_seqs_lods, + const int *encoder_batch_map, + const int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t hidden_dim) { int ret = 0; int cur_offset = 0; int en_idx = 0; @@ -48,7 +57,8 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, int cpy_m = 0; if (de_batch > 0 && decoder_batch_map[de_idx] == i) { cpy_m = 1; - ret = api::cast(ctx, x + cur_offset * hidden_dim, + ret = api::cast(ctx, + x + cur_offset * hidden_dim, y + (encoder_len_total + de_idx) * hidden_dim, cpy_m * hidden_dim); WRAPPER_ASSERT_SUCCESS(ctx, ret); @@ -56,7 +66,8 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, } if (en_batch > 0 && encoder_batch_map[en_idx] == i) { cpy_m = encoder_seqs_lods[en_idx + 1] - encoder_seqs_lods[en_idx]; - ret = api::cast(ctx, x + cur_offset * hidden_dim, + ret = api::cast(ctx, + x + cur_offset * hidden_dim, y + encoder_seqs_lods[en_idx] * hidden_dim, cpy_m * hidden_dim); WRAPPER_ASSERT_SUCCESS(ctx, ret); @@ -69,11 +80,15 @@ cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, } template -static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y, - api::VectorParam &encoder_seqs_lods, // NOLINT - api::VectorParam &encoder_batch_map, // NOLINT - api::VectorParam &decoder_batch_map, // NOLINT - int en_batch, int de_batch, int64_t hidden_dim) { +static int xpu3_wrapper(api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int en_batch, + int de_batch, + int64_t hidden_dim) { using XPU_INDEX_TYPE_TX = typename XPUIndexType::type; using XPU_INDEX_TYPE_TY = typename XPUIndexType::type; auto eb_adjust_batch_kernel = @@ -81,17 +96,23 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y, // NOTE: Don't change 16 to 64, because kernel use gsm eb_adjust_batch_kernel<<ncluster(), 16, ctx->xpu_stream>>>( reinterpret_cast(const_cast(x)), - reinterpret_cast(y), encoder_seqs_lods.xpu, - encoder_batch_map.xpu, decoder_batch_map.xpu, en_batch, de_batch, + reinterpret_cast(y), + encoder_seqs_lods.xpu, + encoder_batch_map.xpu, + decoder_batch_map.xpu, + en_batch, + de_batch, hidden_dim); return api::SUCCESS; } template -int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y, - api::VectorParam &encoder_seqs_lods, // NOLINT - api::VectorParam &encoder_batch_map, // NOLINT - api::VectorParam &decoder_batch_map, // NOLINT +int eb_adjust_batch(api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT int64_t hidden_dim) { // int dev_id = -1; // xpu_current_device(&dev_id); @@ -101,8 +122,13 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y, WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY); - WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map, - decoder_batch_map, hidden_dim); + WRAPPER_DUMP_PARAM6(ctx, + x, + y, + encoder_seqs_lods, + encoder_batch_map, + decoder_batch_map, + hidden_dim); WRAPPER_DUMP(ctx); int encoder_batch = encoder_batch_map.len; int total_batch = encoder_batch + decoder_batch_map.len; @@ -126,9 +152,14 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y, WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch) } if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, x, y, encoder_seqs_lods.cpu, - encoder_batch_map.cpu, decoder_batch_map.cpu, - encoder_batch_map.len, decoder_batch_map.len, + return cpu_wrapper(ctx, + x, + y, + encoder_seqs_lods.cpu, + encoder_batch_map.cpu, + decoder_batch_map.cpu, + encoder_batch_map.len, + decoder_batch_map.len, hidden_dim); } if (ctx->dev().type() == api::kXPU3) { @@ -139,18 +170,27 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y, encoder_batch_map.to_xpu(RAII_GUARD); api::VectorParam decoder_batch_map_xpu = decoder_batch_map.to_xpu(RAII_GUARD); - return xpu3_wrapper(ctx, x, y, encoder_seqs_lods_xpu, - encoder_batch_map_xpu, decoder_batch_map_xpu, - encoder_batch_map.len, decoder_batch_map.len, + return xpu3_wrapper(ctx, + x, + y, + encoder_seqs_lods_xpu, + encoder_batch_map_xpu, + decoder_batch_map_xpu, + encoder_batch_map.len, + decoder_batch_map.len, hidden_dim); } WRAPPER_UNIMPLEMENTED(ctx); } -#define INSTANTIATION_EB_ADJUST_BATCH(TX, TY) \ - template int eb_adjust_batch( \ - api::Context *, const TX *, TY *, api::VectorParam &, \ - api::VectorParam &, api::VectorParam &, int64_t); +#define INSTANTIATION_EB_ADJUST_BATCH(TX, TY) \ + template int eb_adjust_batch(api::Context *, \ + const TX *, \ + TY *, \ + api::VectorParam &, \ + api::VectorParam &, \ + api::VectorParam &, \ + int64_t); INSTANTIATION_EB_ADJUST_BATCH(float16, float16); INSTANTIATION_EB_ADJUST_BATCH(bfloat16, bfloat16); @@ -163,7 +203,7 @@ INSTANTIATION_EB_ADJUST_BATCH(bfloat16, float); INSTANTIATION_EB_ADJUST_BATCH(float, bfloat16); INSTANTIATION_EB_ADJUST_BATCH(int32_t, int32_t); INSTANTIATION_EB_ADJUST_BATCH(int64_t, int64_t); -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_gather_next_token.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_gather_next_token.cpp index 5ee3c8833..ac3d5731f 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_gather_next_token.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_gather_next_token.cpp @@ -20,62 +20,92 @@ namespace xpu3 { namespace plugin { template -__attribute__((global)) void -eb_gather_next_token(TX *src, TY *dst, int *encoder_seqs_lods, - int *encoder_batch_map, int *decoder_batch_map, - int en_batch, int de_batch, int64_t copy_size); -} // namespace plugin -} // namespace xpu3 +__attribute__((global)) void eb_gather_next_token(TX *src, + TY *dst, + int *encoder_seqs_lods, + int *encoder_batch_map, + int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size); +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { template -static int -cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, - const int *encoder_batch_map, const int *decoder_batch_map, - int en_batch, int de_batch, int64_t hidden_dim) { +static int cpu_wrapper(api::Context *ctx, + const TX *x, + TY *y, + const int *encoder_seqs_lods, + const int *encoder_batch_map, + const int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t hidden_dim) { int ret = 0; int encoder_len_total = encoder_seqs_lods[en_batch]; for (int i = 0; i < en_batch; i++) { - ret = - api::cast(ctx, x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim, - y + encoder_batch_map[i] * hidden_dim, hidden_dim); + ret = api::cast(ctx, + x + (encoder_seqs_lods[i + 1] - 1) * hidden_dim, + y + encoder_batch_map[i] * hidden_dim, + hidden_dim); WRAPPER_ASSERT_SUCCESS(ctx, ret); } for (int i = 0; i < de_batch; i++) { - ret = api::cast(ctx, x + (encoder_len_total + i) * hidden_dim, - y + decoder_batch_map[i] * hidden_dim, hidden_dim); + ret = api::cast(ctx, + x + (encoder_len_total + i) * hidden_dim, + y + decoder_batch_map[i] * hidden_dim, + hidden_dim); WRAPPER_ASSERT_SUCCESS(ctx, ret); } return api::SUCCESS; } template -static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y, - api::VectorParam &encoder_seqs_lods, // NOLINT - api::VectorParam &encoder_batch_map, // NOLINT - api::VectorParam &decoder_batch_map, // NOLINT - int en_batch, int de_batch, int64_t hidden_dim) { +static int xpu3_wrapper(api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int en_batch, + int de_batch, + int64_t hidden_dim) { auto eb_gather_next_token_kernel = xpu3::plugin::eb_gather_next_token; // NOTE: Don't change 16 to 64, because kernel use gsm eb_gather_next_token_kernel<<ncluster(), 16, ctx->xpu_stream>>>( - const_cast(x), y, encoder_seqs_lods.xpu, encoder_batch_map.xpu, - decoder_batch_map.xpu, en_batch, de_batch, hidden_dim); + const_cast(x), + y, + encoder_seqs_lods.xpu, + encoder_batch_map.xpu, + decoder_batch_map.xpu, + en_batch, + de_batch, + hidden_dim); return api::SUCCESS; } template -int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y, - api::VectorParam &encoder_seqs_lods, // NOLINT - api::VectorParam &encoder_batch_map, // NOLINT - api::VectorParam &decoder_batch_map, // NOLINT - int64_t hidden_dim) { +int eb_gather_next_token( + api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int64_t hidden_dim) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_gather_next_token", TX, TY); - WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, encoder_batch_map, - decoder_batch_map, hidden_dim); + WRAPPER_DUMP_PARAM6(ctx, + x, + y, + encoder_seqs_lods, + encoder_batch_map, + decoder_batch_map, + hidden_dim); WRAPPER_DUMP(ctx); int encoder_batch = encoder_batch_map.len; int batch = encoder_batch + decoder_batch_map.len; @@ -99,9 +129,14 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y, WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0); } if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, x, y, encoder_seqs_lods.cpu, - encoder_batch_map.cpu, decoder_batch_map.cpu, - encoder_batch_map.len, decoder_batch_map.len, + return cpu_wrapper(ctx, + x, + y, + encoder_seqs_lods.cpu, + encoder_batch_map.cpu, + decoder_batch_map.cpu, + encoder_batch_map.len, + decoder_batch_map.len, hidden_dim); } if (ctx->dev().type() == api::kXPU3) { @@ -112,17 +147,26 @@ int eb_gather_next_token(api::Context *ctx, const TX *x, TY *y, encoder_batch_map.to_xpu(RAII_GUARD); api::VectorParam decoder_batch_map_xpu = decoder_batch_map.to_xpu(RAII_GUARD); - return xpu3_wrapper(ctx, x, y, encoder_seqs_lods_xpu, - encoder_batch_map_xpu, decoder_batch_map_xpu, - encoder_batch_map.len, decoder_batch_map.len, + return xpu3_wrapper(ctx, + x, + y, + encoder_seqs_lods_xpu, + encoder_batch_map_xpu, + decoder_batch_map_xpu, + encoder_batch_map.len, + decoder_batch_map.len, hidden_dim); } WRAPPER_UNIMPLEMENTED(ctx); } -#define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \ - template int eb_gather_next_token( \ - api::Context *, const TX *, TY *, api::VectorParam &, \ - api::VectorParam &, api::VectorParam &, int64_t); +#define INSTANTIATION_EB_GATHER_NEXT_TOKEN(TX, TY) \ + template int eb_gather_next_token(api::Context *, \ + const TX *, \ + TY *, \ + api::VectorParam &, \ + api::VectorParam &, \ + api::VectorParam &, \ + int64_t); INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, float16); INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, bfloat16); @@ -133,7 +177,7 @@ INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, float16); INSTANTIATION_EB_GATHER_NEXT_TOKEN(float16, bfloat16); INSTANTIATION_EB_GATHER_NEXT_TOKEN(bfloat16, float); INSTANTIATION_EB_GATHER_NEXT_TOKEN(float, bfloat16); -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/free_and_dispatch_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/free_and_dispatch_block.cpp index 88e00b9e1..c2d2789c9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/free_and_dispatch_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/free_and_dispatch_block.cpp @@ -12,211 +12,304 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { __attribute__((global)) void free_and_dispatch_block( - bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder, - int *block_tables, int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz, - const int block_size, const int block_num_per_seq, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, const int max_decoder_block_num); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, - int *seq_lens_decoder, int *block_tables, - int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, - int *used_list_len, int *free_list, int *free_list_len, - int64_t *first_token_ids, const int bsz, - const int block_size, const int block_num_per_seq, +static int cpu_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, const int max_decoder_block_num) { + for (int i = 0; i < bsz; i++) { + int *block_table_now = block_tables + i * block_num_per_seq; + if (stop_flags[i] && !is_block_step[i]) { + // 回收block块 + const int encoder_block_len = encoder_block_lens[i]; + const int decoder_used_len = used_list_len[i]; + if (decoder_used_len > 0) { + const int ori_free_list_len = free_list_len[0]; + free_list_len[0] += decoder_used_len; + for (int j = 0; j < decoder_used_len; j++) { + free_list[ori_free_list_len + j] = + block_table_now[encoder_block_len + j]; + block_table_now[encoder_block_len + j] = -1; + } + encoder_block_lens[i] = 0; + used_list_len[i] = 0; + } + } else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = need_block_len[0]; + need_block_len[0] += 1; + need_block_list[ori_need_block_len] = i; + } + } + + while (need_block_len[0] > free_list_len[0]) { + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len + int max_used_list_len_id = 0; + int max_used_list_len = 0; for (int i = 0; i < bsz; i++) { - int *block_table_now = block_tables + i * block_num_per_seq; - if (stop_flags[i] && !is_block_step[i]) { - // 回收block块 - const int encoder_block_len = encoder_block_lens[i]; - const int decoder_used_len = used_list_len[i]; - if (decoder_used_len > 0) { - const int ori_free_list_len = free_list_len[0]; - free_list_len[0] += decoder_used_len; - for (int j = 0; j < decoder_used_len; j++) { - free_list[ori_free_list_len + j] = - block_table_now[encoder_block_len + j]; - block_table_now[encoder_block_len + j] = -1; - } - encoder_block_lens[i] = 0; - used_list_len[i] = 0; - } - } else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) { - // 统计需要分配block的位置和总数 - const int ori_need_block_len = need_block_len[0]; - need_block_len[0] += 1; - need_block_list[ori_need_block_len] = i; - } + const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0; + if (used_block_num > max_used_list_len) { + max_used_list_len_id = i; + max_used_list_len = used_block_num; + } } - while (need_block_len[0] > free_list_len[0]) { - // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len - int max_used_list_len_id = 0; - int max_used_list_len = 0; - for (int i = 0; i < bsz; i++) { - const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0; - if (used_block_num > max_used_list_len) { - max_used_list_len_id = i; - max_used_list_len = used_block_num; - } - } - - const int encoder_block_len = encoder_block_lens[max_used_list_len_id]; - int *block_table_now = - block_tables + max_used_list_len_id * block_num_per_seq; - for (int i = 0; i < max_used_list_len; i++) { - free_list[free_list_len[0] + i] = - block_table_now[encoder_block_len + i]; - block_table_now[encoder_block_len + i] = -1; - } - step_block_list[step_len[0]] = max_used_list_len_id; - step_len[0] += 1; - free_list_len[0] += max_used_list_len; - stop_flags[max_used_list_len_id] = true; - is_block_step[max_used_list_len_id] = true; - seq_lens_this_time[max_used_list_len_id] = 0; - seq_lens_decoder[max_used_list_len_id] = 0; + const int encoder_block_len = encoder_block_lens[max_used_list_len_id]; + int *block_table_now = + block_tables + max_used_list_len_id * block_num_per_seq; + for (int i = 0; i < max_used_list_len; i++) { + free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; } + step_block_list[step_len[0]] = max_used_list_len_id; + step_len[0] += 1; + free_list_len[0] += max_used_list_len; + stop_flags[max_used_list_len_id] = true; + is_block_step[max_used_list_len_id] = true; + seq_lens_this_time[max_used_list_len_id] = 0; + seq_lens_decoder[max_used_list_len_id] = 0; + } - // 为需要block的位置分配block,每个位置分配一个block - for (int i = 0; i < bsz; i++) { - if (i < need_block_len[0]) { - const int need_block_id = need_block_list[i]; - if (!stop_flags[need_block_id]) { - // 如果需要的位置正好是上一步中被释放的位置,不做处理 - used_list_len[need_block_id] += 1; - const int ori_free_list_len = free_list_len[0]; - free_list_len[0]--; - int *block_table_now = - block_tables + need_block_id * block_num_per_seq; - block_table_now[seq_lens_decoder[need_block_id] / block_size] = - free_list[ori_free_list_len - 1]; - } - need_block_list[i] = -1; - } + // 为需要block的位置分配block,每个位置分配一个block + for (int i = 0; i < bsz; i++) { + if (i < need_block_len[0]) { + const int need_block_id = need_block_list[i]; + if (!stop_flags[need_block_id]) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len[need_block_id] += 1; + const int ori_free_list_len = free_list_len[0]; + free_list_len[0]--; + int *block_table_now = block_tables + need_block_id * block_num_per_seq; + block_table_now[seq_lens_decoder[need_block_id] / block_size] = + free_list[ori_free_list_len - 1]; + } + need_block_list[i] = -1; } + } - // 计算可以复原的query id - int ori_step_len = step_len[0]; - if (ori_step_len > 0) { - int ori_free_list_len = free_list_len[0]; - int ori_step_block_id = step_block_list[ori_step_len - 1]; - int tmp_used_len = used_list_len[ori_step_block_id]; - // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) - int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 - : tmp_used_len; - while (ori_step_len > 0 && ori_free_list_len >= used_len) { - recover_block_list[recover_len[0]] = ori_step_block_id; - is_block_step[ori_step_block_id] = false; - used_list_len[ori_step_block_id] = used_len; - ori_free_list_len -= used_len; - step_block_list[ori_step_len - 1] = -1; - step_len[0] -= 1; - recover_len[0] += 1; - ori_step_len = step_len[0]; - if (ori_step_len > 0) { - ori_step_block_id = step_block_list[ori_step_len - 1]; - tmp_used_len = used_list_len[ori_step_block_id]; - used_len = tmp_used_len < max_decoder_block_num - ? tmp_used_len + 1 - : tmp_used_len; - } - } - need_block_len[0] = 0; + // 计算可以复原的query id + int ori_step_len = step_len[0]; + if (ori_step_len > 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = + tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; + while (ori_step_len > 0 && ori_free_list_len >= used_len) { + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step[ori_step_block_id] = false; + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 + : tmp_used_len; + } } - return api::SUCCESS; + need_block_len[0] = 0; + } + return api::SUCCESS; } -static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, - int *seq_lens_decoder, int *block_tables, - int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, - int *recover_len, int *need_block_list, - int *need_block_len, int *used_list_len, int *free_list, - int *free_list_len, int64_t *first_token_ids, - const int bsz, const int block_size, +static int xpu3_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + const int bsz, + const int block_size, const int block_num_per_seq, const int max_decoder_block_num) { - using XPU_INT64 = typename XPUIndexType::type; - auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block; - free_and_dispatch_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, - encoder_block_lens, is_block_step, step_block_list, step_len, - recover_block_list, recover_len, need_block_list, need_block_len, - used_list_len, free_list, free_list_len, - reinterpret_cast(first_token_ids), bsz, block_size, - block_num_per_seq, max_decoder_block_num); - return api::SUCCESS; + using XPU_INT64 = typename XPUIndexType::type; + auto free_and_dispatch_block_kernel = xpu3::plugin::free_and_dispatch_block; + free_and_dispatch_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + reinterpret_cast(first_token_ids), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num); + return api::SUCCESS; } -int free_and_dispatch_block(Context *ctx, bool *stop_flags, - int *seq_lens_this_time, int *seq_lens_decoder, - int *block_tables, int *encoder_block_lens, +int free_and_dispatch_block(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, - int *recover_len, int *need_block_list, - int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, - int64_t *first_token_ids, const int bsz, - const int block_size, const int block_num_per_seq, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, const int max_decoder_block_num) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "free_and_dispatch_block", float); - WRAPPER_DUMP_PARAM6(ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, - block_tables, encoder_block_lens, is_block_step); - WRAPPER_DUMP_PARAM6(ctx, step_block_list, step_len, recover_block_list, - recover_len, need_block_list, need_block_len); - WRAPPER_DUMP_PARAM4(ctx, used_list_len, free_list, free_list_len, - first_token_ids); - WRAPPER_DUMP_PARAM4(ctx, bsz, block_size, block_num_per_seq, + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "free_and_dispatch_block", float); + WRAPPER_DUMP_PARAM6(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step); + WRAPPER_DUMP_PARAM6(ctx, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len); + WRAPPER_DUMP_PARAM4( + ctx, used_list_len, free_list, free_list_len, first_token_ids); + WRAPPER_DUMP_PARAM4( + ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + bsz, + block_size, + block_num_per_seq, max_decoder_block_num); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper( - ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, - encoder_block_lens, is_block_step, step_block_list, step_len, - recover_block_list, recover_len, need_block_list, need_block_len, - used_list_len, free_list, free_list_len, first_token_ids, bsz, - block_size, block_num_per_seq, max_decoder_block_num); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper( - ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, - encoder_block_lens, is_block_step, step_block_list, step_len, - recover_block_list, recover_len, need_block_list, need_block_len, - used_list_len, free_list, free_list_len, first_token_ids, bsz, - block_size, block_num_per_seq, max_decoder_block_num); - } - WRAPPER_UNIMPLEMENTED(ctx); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp index 2960f2dd6..e0f7c013f 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/get_padding_offset.cpp @@ -12,120 +12,178 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { -__attribute__((global)) void -get_padding_offset(int *padding_offset, int *cum_offsets_out, int *cu_seqlens_q, - int *cu_seqlens_k, const int *cum_offsets, - const int *seq_lens, const int max_seq_len, const int bs); -__attribute__((global)) void -remove_padding(int64_t *x_remove_padding, const int64_t *input_data, - const int *seq_lens, const int *cum_offsets, - const int sequence_length, const int bs); +__attribute__((global)) void get_padding_offset(int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs); +__attribute__((global)) void remove_padding(int64_t *x_remove_padding, + const int64_t *input_data, + const int *seq_lens, + const int *cum_offsets, + const int sequence_length, + const int bs); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int get_padding_offset_cpu(int *padding_offset, int *cum_offsets_out, - int *cu_seqlens_q, int *cu_seqlens_k, - const int *cum_offsets, const int *seq_lens, - const int max_seq_len, const int bs) { - for (int i = 0; i < bs; i++) { - int cum_offset = i == 0 ? 0 : cum_offsets[i - 1]; - for (int j = 0; j < seq_lens[i]; j++) { - padding_offset[i * max_seq_len - cum_offset + j] = cum_offset; - } - cum_offsets_out[i] = cum_offset; - int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i]; - cu_seqlens_q[i + 1] = cum_seq_len; - cu_seqlens_k[i + 1] = cum_seq_len; +static int get_padding_offset_cpu(int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + for (int i = 0; i < bs; i++) { + int cum_offset = i == 0 ? 0 : cum_offsets[i - 1]; + for (int j = 0; j < seq_lens[i]; j++) { + padding_offset[i * max_seq_len - cum_offset + j] = cum_offset; } - return api::SUCCESS; + cum_offsets_out[i] = cum_offset; + int cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i]; + cu_seqlens_q[i + 1] = cum_seq_len; + cu_seqlens_k[i + 1] = cum_seq_len; + } + return api::SUCCESS; } static int remove_padding_cpu(int64_t *x_remove_padding, - const int64_t *input_data, const int *seq_lens, - const int *cum_offsets, const int sequence_length, + const int64_t *input_data, + const int *seq_lens, + const int *cum_offsets, + const int sequence_length, const int bs) { - for (int i = 0; i < bs; i++) { - for (int j = 0; j < seq_lens[i]; j++) { - const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j; - const int src_seq_id = i * sequence_length + j; - x_remove_padding[tgt_seq_id] = input_data[src_seq_id]; - } + for (int i = 0; i < bs; i++) { + for (int j = 0; j < seq_lens[i]; j++) { + const int tgt_seq_id = i * sequence_length - cum_offsets[i] + j; + const int src_seq_id = i * sequence_length + j; + x_remove_padding[tgt_seq_id] = input_data[src_seq_id]; } - return api::SUCCESS; + } + return api::SUCCESS; } -static int cpu_wrapper(Context *ctx, int *padding_offset, int *cum_offsets_out, - int *cu_seqlens_q, int *cu_seqlens_k, - int64_t *x_remove_padding, const int64_t *input_ids, - const int *cum_offsets, const int *seq_lens, - const int max_seq_len, const int bs) { - get_padding_offset_cpu(padding_offset, cum_offsets_out, cu_seqlens_q, - cu_seqlens_k, cum_offsets, seq_lens, max_seq_len, - bs); - remove_padding_cpu(x_remove_padding, input_ids, seq_lens, cum_offsets_out, - max_seq_len, bs); - return api::SUCCESS; +static int cpu_wrapper(Context *ctx, + int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + int64_t *x_remove_padding, + const int64_t *input_ids, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + get_padding_offset_cpu(padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bs); + remove_padding_cpu( + x_remove_padding, input_ids, seq_lens, cum_offsets_out, max_seq_len, bs); + return api::SUCCESS; } -static int xpu3_wrapper(Context *ctx, int *padding_offset, int *cum_offsets_out, - int *cu_seqlens_q, int *cu_seqlens_k, - int64_t *x_remove_padding, const int64_t *input_ids, - const int *cum_offsets, const int *seq_lens, - const int max_seq_len, const int bs) { - using XPU_INT64 = typename XPUIndexType::type; - auto get_padding_offset = xpu3::plugin::get_padding_offset; - auto remove_padding = xpu3::plugin::remove_padding; - get_padding_offset<<ncluster(), 64, ctx->xpu_stream>>>( - padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, - cum_offsets, seq_lens, max_seq_len, bs); - remove_padding<<ncluster(), 64, ctx->xpu_stream>>>( - reinterpret_cast(x_remove_padding), - reinterpret_cast(input_ids), seq_lens, - cum_offsets_out, max_seq_len, bs); - return api::SUCCESS; +static int xpu3_wrapper(Context *ctx, + int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + int64_t *x_remove_padding, + const int64_t *input_ids, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + using XPU_INT64 = typename XPUIndexType::type; + auto get_padding_offset = xpu3::plugin::get_padding_offset; + auto remove_padding = xpu3::plugin::remove_padding; + get_padding_offset<<ncluster(), 64, ctx->xpu_stream>>>(padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bs); + remove_padding<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(x_remove_padding), + reinterpret_cast(input_ids), + seq_lens, + cum_offsets_out, + max_seq_len, + bs); + return api::SUCCESS; } -int get_padding_offset(Context *ctx, int *padding_offset, int *cum_offsets_out, - int *cu_seqlens_q, int *cu_seqlens_k, - int64_t *x_remove_padding, const int64_t *input_ids, - const int *cum_offsets, const int *seq_lens, - const int max_seq_len, const int bs) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int); - WRAPPER_DUMP_PARAM4(ctx, padding_offset, cum_offsets_out, cu_seqlens_q, - cu_seqlens_k); - WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, - seq_lens); - WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, padding_offset, cum_offsets_out, cu_seqlens_q, - cu_seqlens_k, x_remove_padding, input_ids, - cum_offsets, seq_lens, max_seq_len, bs); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, padding_offset, cum_offsets_out, cu_seqlens_q, - cu_seqlens_k, x_remove_padding, input_ids, - cum_offsets, seq_lens, max_seq_len, bs); - } - WRAPPER_UNIMPLEMENTED(ctx); +int get_padding_offset(Context *ctx, + int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + int64_t *x_remove_padding, + const int64_t *input_ids, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len, + const int bs) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "get_padding_offset", int); + WRAPPER_DUMP_PARAM4( + ctx, padding_offset, cum_offsets_out, cu_seqlens_q, cu_seqlens_k); + WRAPPER_DUMP_PARAM4(ctx, x_remove_padding, input_ids, cum_offsets, seq_lens); + WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bs); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + x_remove_padding, + input_ids, + cum_offsets, + seq_lens, + max_seq_len, + bs); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + padding_offset, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + x_remove_padding, + input_ids, + cum_offsets, + seq_lens, + max_seq_len, + bs); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp index a62937941..36d2d446b 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_postprocess.cpp @@ -76,19 +76,20 @@ static int cpu_wrapper( } static int xpu3_wrapper(Context* ctx, - const int64_t* base_model_draft_tokens, - int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const bool* base_model_stop_flags, - int bsz, - int base_model_draft_token_len) { - xpu3::plugin::draft_model_postprocess<<ncluster(), 64, ctx->xpu_stream>>>( - reinterpret_cast(base_model_draft_tokens), - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_stop_flags, - bsz, - base_model_draft_token_len); + const int64_t* base_model_draft_tokens, + int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const bool* base_model_stop_flags, + int bsz, + int base_model_draft_token_len) { + xpu3::plugin:: + draft_model_postprocess<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(base_model_draft_tokens), + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_stop_flags, + bsz, + base_model_draft_token_len); return api::SUCCESS; } @@ -124,12 +125,12 @@ int draft_model_postprocess(Context* ctx, } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, - base_model_draft_tokens, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_stop_flags, - bsz, - base_model_draft_token_len); + base_model_draft_tokens, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_stop_flags, + bsz, + base_model_draft_token_len); } WRAPPER_UNIMPLEMENTED(ctx); } diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp index 972f46f00..e9d4e8fb3 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_stop_value_multi_ends.cpp @@ -21,13 +21,18 @@ namespace xpu3 { namespace plugin { template -__attribute__((global)) void -set_stop_value_multi_ends(bool *stop_flags, T *topk_ids, T *next_tokens, - const T *end_ids, const int *seq_lens, const int bs, - const int end_length, const bool beam_search, - const bool prefill_one_step_stop); -} // namespace plugin -} // namespace xpu3 +__attribute__((global)) void set_stop_value_multi_ends( + bool *stop_flags, + T *topk_ids, + T *next_tokens, + const T *end_ids, + const int *seq_lens, + const int bs, + const int end_length, + const bool beam_search, + const bool prefill_one_step_stop); +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { @@ -36,104 +41,143 @@ namespace plugin { template __inline__ bool is_in_end(const T id, const T *end_ids, int length) { - for (int i = 0; i < length; i++) { - if (id == end_ids[i]) { - return true; - } + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; } - return false; + } + return false; } template -static int cpu_wrapper(Context *ctx, bool *stop_flags, T *topk_ids, - T *next_tokens, const T *end_ids, const int *seq_lens, - const int bs, const int end_length, +static int cpu_wrapper(Context *ctx, + bool *stop_flags, + T *topk_ids, + T *next_tokens, + const T *end_ids, + const int *seq_lens, + const int bs, + const int end_length, const bool beam_search, const bool prefill_one_step_stop) { - for (int i = 0; i < bs; i++) { - if (prefill_one_step_stop) { - stop_flags[i] = true; - if (seq_lens[i] == 0) { - topk_ids[i] = -1; - } - next_tokens[i] = topk_ids[i]; + for (int i = 0; i < bs; i++) { + if (prefill_one_step_stop) { + stop_flags[i] = true; + if (seq_lens[i] == 0) { + topk_ids[i] = -1; + } + next_tokens[i] = topk_ids[i]; + } else { + if (stop_flags[i]) { + if (seq_lens[i] == 0) { + topk_ids[i] = -1; } else { - if (stop_flags[i]) { - if (seq_lens[i] == 0) { - topk_ids[i] = -1; - } else { - topk_ids[i] = end_ids[0]; - next_tokens[i] = end_ids[0]; - } - } else { - next_tokens[i] = topk_ids[i]; - } - if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) { - stop_flags[i] = true; - } + topk_ids[i] = end_ids[0]; + next_tokens[i] = end_ids[0]; } + } else { + next_tokens[i] = topk_ids[i]; + } + if (!beam_search && is_in_end(topk_ids[i], end_ids, end_length)) { + stop_flags[i] = true; + } } - return api::SUCCESS; + } + return api::SUCCESS; } template -static int xpu3_wrapper(Context *ctx, bool *stop_flags, T *topk_ids, - T *next_tokens, const T *end_ids, const int *seq_lens, - const int bs, const int end_length, +static int xpu3_wrapper(Context *ctx, + bool *stop_flags, + T *topk_ids, + T *next_tokens, + const T *end_ids, + const int *seq_lens, + const int bs, + const int end_length, const bool beam_search, const bool prefill_one_step_stop) { - using XPU_TID = typename XPUIndexType::type; - auto set_stop_value_multi_ends = - xpu3::plugin::set_stop_value_multi_ends; - set_stop_value_multi_ends<<ncluster(), 64, ctx->xpu_stream>>>( - stop_flags, reinterpret_cast(topk_ids), - reinterpret_cast(next_tokens), - reinterpret_cast(end_ids), seq_lens, bs, end_length, - beam_search, prefill_one_step_stop); - return api::SUCCESS; + using XPU_TID = typename XPUIndexType::type; + auto set_stop_value_multi_ends = + xpu3::plugin::set_stop_value_multi_ends; + set_stop_value_multi_ends<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + reinterpret_cast(topk_ids), + reinterpret_cast(next_tokens), + reinterpret_cast(end_ids), + seq_lens, + bs, + end_length, + beam_search, + prefill_one_step_stop); + return api::SUCCESS; } template -int set_stop_value_multi_ends(Context *ctx, bool *stop_flags, T *topk_ids, - T *next_tokens, const T *end_ids, - const int *seq_lens, const int bs, - const int end_length, const bool beam_search) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "set_stop_value_multi_ends", T); - WRAPPER_DUMP_PARAM5(ctx, stop_flags, topk_ids, next_tokens, end_ids, - seq_lens); - WRAPPER_DUMP_PARAM3(ctx, bs, end_length, beam_search); - WRAPPER_DUMP(ctx); - WRAPPER_CHECK_PTR(ctx, bool, bs, stop_flags); - WRAPPER_CHECK_PTR(ctx, T, bs, topk_ids); - WRAPPER_CHECK_PTR(ctx, T, end_length, end_ids); - WRAPPER_CHECK_PTR(ctx, T, bs, seq_lens); - WRAPPER_ASSERT_LE(ctx, end_length, 1024); // assume end_length <= 1024 - bool prefill_one_step_stop = false; - if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { - // std::cout << "Your PATH is: " << env_p << '\n'; - if (env_p[0] == '1') { - prefill_one_step_stop = true; - } +int set_stop_value_multi_ends(Context *ctx, + bool *stop_flags, + T *topk_ids, + T *next_tokens, + const T *end_ids, + const int *seq_lens, + const int bs, + const int end_length, + const bool beam_search) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "set_stop_value_multi_ends", T); + WRAPPER_DUMP_PARAM5( + ctx, stop_flags, topk_ids, next_tokens, end_ids, seq_lens); + WRAPPER_DUMP_PARAM3(ctx, bs, end_length, beam_search); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_PTR(ctx, bool, bs, stop_flags); + WRAPPER_CHECK_PTR(ctx, T, bs, topk_ids); + WRAPPER_CHECK_PTR(ctx, T, end_length, end_ids); + WRAPPER_CHECK_PTR(ctx, T, bs, seq_lens); + WRAPPER_ASSERT_LE(ctx, end_length, 1024); // assume end_length <= 1024 + bool prefill_one_step_stop = false; + if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { + // std::cout << "Your PATH is: " << env_p << '\n'; + if (env_p[0] == '1') { + prefill_one_step_stop = true; } - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, stop_flags, topk_ids, next_tokens, end_ids, - seq_lens, bs, end_length, beam_search, - prefill_one_step_stop); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, stop_flags, topk_ids, next_tokens, end_ids, - seq_lens, bs, end_length, beam_search, - prefill_one_step_stop); - } - WRAPPER_UNIMPLEMENTED(ctx); + } + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + topk_ids, + next_tokens, + end_ids, + seq_lens, + bs, + end_length, + beam_search, + prefill_one_step_stop); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + topk_ids, + next_tokens, + end_ids, + seq_lens, + bs, + end_length, + beam_search, + prefill_one_step_stop); + } + WRAPPER_UNIMPLEMENTED(ctx); } -template int set_stop_value_multi_ends( - Context *ctx, bool *stop_flags, int64_t *topk_ids, int64_t *next_tokens, - const int64_t *end_ids, const int *seq_lens, const int bs, - const int end_length, const bool beam_search); -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +template int set_stop_value_multi_ends(Context *ctx, + bool *stop_flags, + int64_t *topk_ids, + int64_t *next_tokens, + const int64_t *end_ids, + const int *seq_lens, + const int bs, + const int end_length, + const bool beam_search); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp index 39f39ff83..123b56b72 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_set_value_by_flags_and_idx.cpp @@ -12,128 +12,173 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { __attribute__((global)) void set_value_by_flags_and_idx( - const bool *stop_flags, int64_t *pre_ids_all, const int64_t *input_ids, - const int *seq_lens_encoder, const int *seq_lens_decoder, - const int64_t *step_idx, int bs, int length, int length_input_ids); + const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *input_ids, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int length_input_ids); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int cpu_wrapper(Context *ctx, const bool *stop_flags, - int64_t *pre_ids_all, const int64_t *pre_ids, - const int64_t *step_idx, const int bs, +static int cpu_wrapper(Context *ctx, + const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *pre_ids, + const int64_t *step_idx, + const int bs, const int length) { - for (int i = 0; i < bs; i++) { - int64_t *pre_ids_all_now = pre_ids_all + i * length; - if (!stop_flags[i] && step_idx[i] >= 0) { - pre_ids_all_now[step_idx[i]] = pre_ids[i]; - } + for (int i = 0; i < bs; i++) { + int64_t *pre_ids_all_now = pre_ids_all + i * length; + if (!stop_flags[i] && step_idx[i] >= 0) { + pre_ids_all_now[step_idx[i]] = pre_ids[i]; } - return api::SUCCESS; + } + return api::SUCCESS; } -static int cpu_wrapper(Context *ctx, const bool *stop_flags, - int64_t *pre_ids_all, const int64_t *input_ids, - const int *seq_lens_encoder, const int *seq_lens_decoder, - const int64_t *step_idx, int bs, int length, +static int cpu_wrapper(Context *ctx, + const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *input_ids, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, int length_input_ids) { - for (int i = 0; i < bs; i++) { - if (!stop_flags[i]) { - int64_t *pre_ids_all_now = pre_ids_all + i * length; - const int64_t *input_ids_now = input_ids + i * length_input_ids; - const int seq_len_dec = seq_lens_decoder[i]; - const int seq_len_enc = seq_lens_encoder[i]; - if (seq_len_dec == 0 && seq_len_enc == 0) - continue; - if (step_idx[i] >= 0) { - if (seq_len_enc > 0) { - // encoder, get last token accord to seq_lens_encoder - pre_ids_all_now[step_idx[i]] = - input_ids_now[seq_len_enc - 1]; - } else { - // decoder, get first token - pre_ids_all_now[step_idx[i]] = input_ids_now[0]; - } - } + for (int i = 0; i < bs; i++) { + if (!stop_flags[i]) { + int64_t *pre_ids_all_now = pre_ids_all + i * length; + const int64_t *input_ids_now = input_ids + i * length_input_ids; + const int seq_len_dec = seq_lens_decoder[i]; + const int seq_len_enc = seq_lens_encoder[i]; + if (seq_len_dec == 0 && seq_len_enc == 0) continue; + if (step_idx[i] >= 0) { + if (seq_len_enc > 0) { + // encoder, get last token accord to seq_lens_encoder + pre_ids_all_now[step_idx[i]] = input_ids_now[seq_len_enc - 1]; + } else { + // decoder, get first token + pre_ids_all_now[step_idx[i]] = input_ids_now[0]; } + } } - return api::SUCCESS; + } + return api::SUCCESS; } -static int xpu3_wrapper(Context *ctx, const bool *stop_flags, - int64_t *pre_ids_all, const int64_t *input_ids, +static int xpu3_wrapper(Context *ctx, + const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *input_ids, const int *seq_lens_encoder, - const int *seq_lens_decoder, const int64_t *step_idx, - int bs, int length, int length_input_ids) { - using XPU_INT64 = typename XPUIndexType::type; - auto set_value_by_flags_and_idx_kernel = - xpu3::plugin::set_value_by_flags_and_idx; - set_value_by_flags_and_idx_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - stop_flags, reinterpret_cast(pre_ids_all), - reinterpret_cast(input_ids), seq_lens_encoder, - seq_lens_decoder, reinterpret_cast(step_idx), bs, - length, length_input_ids); - return api::SUCCESS; + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int length_input_ids) { + using XPU_INT64 = typename XPUIndexType::type; + auto set_value_by_flags_and_idx_kernel = + xpu3::plugin::set_value_by_flags_and_idx; + set_value_by_flags_and_idx_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + reinterpret_cast(pre_ids_all), + reinterpret_cast(input_ids), + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(step_idx), + bs, + length, + length_input_ids); + return api::SUCCESS; } -int set_value_by_flags_and_idx(Context *ctx, const bool *stop_flags, - int64_t *pre_ids_all, const int64_t *input_ids, +int set_value_by_flags_and_idx(Context *ctx, + const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *input_ids, const int *seq_lens_encoder, const int *seq_lens_decoder, - const int64_t *step_idx, int bs, int length, + const int64_t *step_idx, + int bs, + int length, int length_input_ids) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t); - WRAPPER_DUMP_PARAM6(ctx, stop_flags, pre_ids_all, input_ids, - seq_lens_encoder, seq_lens_decoder, step_idx); - WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids); - WRAPPER_DUMP(ctx); - int64_t stop_flags_len = -1; - int64_t pre_ids_all_len = -1; - int64_t input_ids_len = -1; - int64_t seq_lens_encoder_len = -1; - int64_t seq_lens_decoder_len = -1; - int64_t step_idx_len = -1; - WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length}); - WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids}); - WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs}); - WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags); - WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all); - WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids); - WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder); - WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder); - WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, stop_flags, pre_ids_all, input_ids, - seq_lens_encoder, seq_lens_decoder, step_idx, bs, - length, length_input_ids); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, stop_flags, pre_ids_all, input_ids, - seq_lens_encoder, seq_lens_decoder, step_idx, bs, - length, length_input_ids); - } - WRAPPER_UNIMPLEMENTED(ctx); + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "set_value_by_flags_and_idx", int64_t); + WRAPPER_DUMP_PARAM6(ctx, + stop_flags, + pre_ids_all, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx); + WRAPPER_DUMP_PARAM3(ctx, bs, length, length_input_ids); + WRAPPER_DUMP(ctx); + int64_t stop_flags_len = -1; + int64_t pre_ids_all_len = -1; + int64_t input_ids_len = -1; + int64_t seq_lens_encoder_len = -1; + int64_t seq_lens_decoder_len = -1; + int64_t step_idx_len = -1; + WRAPPER_CHECK_SHAPE(ctx, &stop_flags_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &pre_ids_all_len, {bs, length}); + WRAPPER_CHECK_SHAPE(ctx, &input_ids_len, {bs, length_input_ids}); + WRAPPER_CHECK_SHAPE(ctx, &seq_lens_encoder_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &seq_lens_decoder_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &step_idx_len, {bs}); + WRAPPER_CHECK_PTR(ctx, int64_t, stop_flags_len, stop_flags); + WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_all_len, pre_ids_all); + WRAPPER_CHECK_PTR(ctx, int64_t, input_ids_len, input_ids); + WRAPPER_CHECK_PTR(ctx, int, seq_lens_encoder_len, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, seq_lens_decoder_len, seq_lens_decoder); + WRAPPER_CHECK_PTR(ctx, int64_t, step_idx_len, step_idx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + pre_ids_all, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + bs, + length, + length_input_ids); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + pre_ids_all, + input_ids, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + bs, + length, + length_input_ids); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp index 9170890c9..6698215f8 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/nn_token_penalty_multi_scores.cpp @@ -12,263 +12,367 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { template -__attribute__((global)) void -min_length_logits_process(T *logits, const int64_t *cur_len, - const int64_t *min_len, const int64_t *eos_token_id, - const int64_t bs, const int64_t length, - const int64_t length_id, const int64_t end_length); -__attribute__((global)) void -update_repeat_times(const int64_t *pre_ids, const int64_t *cur_len, - int *repeat_times, const int64_t bs, const int64_t length, - const int64_t length_id); +__attribute__((global)) void min_length_logits_process( + T *logits, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length); +__attribute__((global)) void update_repeat_times(const int64_t *pre_ids, + const int64_t *cur_len, + int *repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id); template -__attribute__((global)) void -update_value_by_repeat_times(const int *repeat_times, const T *penalty_scores, - const T *frequency_score, const T *presence_score, - const float *temperatures, T *logits, - const int64_t bs, const int64_t length); +__attribute__((global)) void update_value_by_repeat_times( + const int *repeat_times, + const T *penalty_scores, + const T *frequency_score, + const T *presence_score, + const float *temperatures, + T *logits, + const int64_t bs, + const int64_t length); template __attribute__((global)) void update_value_by_repeat_times_simd( - const int *repeat_times, const T *penalty_scores, const T *frequency_score, - const T *presence_score, const float *temperatures, T *logits, - const int64_t bs, const int64_t length); + const int *repeat_times, + const T *penalty_scores, + const T *frequency_score, + const T *presence_score, + const float *temperatures, + T *logits, + const int64_t bs, + const int64_t length); template -__attribute__((global)) void -ban_bad_words(T *logits, const int64_t *bad_words_list, const int64_t bs, - const int64_t length, const int64_t bad_words_length); +__attribute__((global)) void ban_bad_words(T *logits, + const int64_t *bad_words_list, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -void update_repeat_times_cpu(const int64_t *pre_ids, const int64_t *cur_len, - int *repeat_times, const int64_t bs, - const int64_t length, const int64_t length_id) { - for (int64_t i = 0; i < bs; i++) { - if (cur_len[i] >= 0) { - for (int64_t j = 0; j < length_id; j++) { - int64_t id = pre_ids[i * length_id + j]; - if (id < 0 || id >= length) - continue; - repeat_times[i * length + id] += 1; - } - } +void update_repeat_times_cpu(const int64_t *pre_ids, + const int64_t *cur_len, + int *repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id) { + for (int64_t i = 0; i < bs; i++) { + if (cur_len[i] >= 0) { + for (int64_t j = 0; j < length_id; j++) { + int64_t id = pre_ids[i * length_id + j]; + if (id < 0 || id >= length) continue; + repeat_times[i * length + id] += 1; + } } + } } -void ban_bad_words_cpu(float *logits, const int64_t *bad_words_list, - const int64_t bs, const int64_t length, +void ban_bad_words_cpu(float *logits, + const int64_t *bad_words_list, + const int64_t bs, + const int64_t length, const int64_t bad_words_length) { - for (int64_t i = 0; i < bs; i++) { - float *logits_now = logits + i * length; - for (int64_t j = 0; j < bad_words_length; j++) { - int64_t bad_words_token_id = bad_words_list[j]; - if (bad_words_token_id >= length || bad_words_token_id < 0) - continue; - logits_now[bad_words_token_id] = -1e10; - } + for (int64_t i = 0; i < bs; i++) { + float *logits_now = logits + i * length; + for (int64_t j = 0; j < bad_words_length; j++) { + int64_t bad_words_token_id = bad_words_list[j]; + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + logits_now[bad_words_token_id] = -1e10; } + } } template -static int cpu_wrapper(Context *ctx, const int64_t *pre_ids, T *logits, - const T *penalty_scores, const T *frequency_scores, - const T *presence_scores, const float *temperatures, - const int64_t *cur_len, const int64_t *min_len, - const int64_t *eos_token_id, const int64_t *bad_words, - const int64_t bs, const int64_t length, - const int64_t length_id, const int64_t end_length, +static int cpu_wrapper(Context *ctx, + const int64_t *pre_ids, + T *logits, + const T *penalty_scores, + const T *frequency_scores, + const T *presence_scores, + const float *temperatures, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t *bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, const int64_t length_bad_words) { - std::vector logitsfp32(bs * length); - std::vector penalty_scoresfp32(bs); - std::vector frequency_scoresfp32(bs); - std::vector presence_scoresfp32(bs); - std::vector repeat_times_buffer(bs * length, 0); - int ret = api::cast(ctx, logits, logitsfp32.data(), bs * length); - WRAPPER_ASSERT_SUCCESS(ctx, ret); - ret = - api::cast(ctx, penalty_scores, penalty_scoresfp32.data(), bs); - WRAPPER_ASSERT_SUCCESS(ctx, ret); - ret = api::cast(ctx, frequency_scores, - frequency_scoresfp32.data(), bs); - WRAPPER_ASSERT_SUCCESS(ctx, ret); - ret = api::cast(ctx, presence_scores, presence_scoresfp32.data(), - bs); - WRAPPER_ASSERT_SUCCESS(ctx, ret); - for (int64_t i = 0; i < bs; i++) { - if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) { - for (int64_t j = 0; j < end_length; j++) { - logitsfp32[i * length + eos_token_id[j]] = -1e4; - } - } + std::vector logitsfp32(bs * length); + std::vector penalty_scoresfp32(bs); + std::vector frequency_scoresfp32(bs); + std::vector presence_scoresfp32(bs); + std::vector repeat_times_buffer(bs * length, 0); + int ret = api::cast(ctx, logits, logitsfp32.data(), bs * length); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = api::cast(ctx, penalty_scores, penalty_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = api::cast( + ctx, frequency_scores, frequency_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + ret = + api::cast(ctx, presence_scores, presence_scoresfp32.data(), bs); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + for (int64_t i = 0; i < bs; i++) { + if (cur_len[i] >= 0 && cur_len[i] < min_len[i]) { + for (int64_t j = 0; j < end_length; j++) { + logitsfp32[i * length + eos_token_id[j]] = -1e4; + } } - int *repeat_times = &(repeat_times_buffer[0]); - update_repeat_times_cpu(pre_ids, cur_len, repeat_times, bs, length, - length_id); - for (int64_t i = 0; i < bs; i++) { - float alpha = penalty_scoresfp32[i]; - float beta = frequency_scoresfp32[i]; - float gamma = presence_scoresfp32[i]; - float temperature = temperatures[i]; - for (int64_t j = 0; j < length; j++) { - int times = repeat_times[i * length + j]; - float logit_now = logitsfp32[i * length + j]; - if (times != 0) { - logit_now = - logit_now < 0 ? logit_now * alpha : logit_now / alpha; - logit_now = logit_now - times * beta - gamma; - } - logitsfp32[i * length + j] = logit_now / temperature; - } + } + int *repeat_times = &(repeat_times_buffer[0]); + update_repeat_times_cpu( + pre_ids, cur_len, repeat_times, bs, length, length_id); + for (int64_t i = 0; i < bs; i++) { + float alpha = penalty_scoresfp32[i]; + float beta = frequency_scoresfp32[i]; + float gamma = presence_scoresfp32[i]; + float temperature = temperatures[i]; + for (int64_t j = 0; j < length; j++) { + int times = repeat_times[i * length + j]; + float logit_now = logitsfp32[i * length + j]; + if (times != 0) { + logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha; + logit_now = logit_now - times * beta - gamma; + } + logitsfp32[i * length + j] = logit_now / temperature; } - if (bad_words && length_bad_words > 0) { - ban_bad_words_cpu(logitsfp32.data(), bad_words, bs, length, - length_bad_words); - } - ret = api::cast(ctx, logitsfp32.data(), logits, bs * length); - return ret; + } + if (bad_words && length_bad_words > 0) { + ban_bad_words_cpu( + logitsfp32.data(), bad_words, bs, length, length_bad_words); + } + ret = api::cast(ctx, logitsfp32.data(), logits, bs * length); + return ret; } template -static int xpu3_wrapper(Context *ctx, const int64_t *pre_ids, T *logits, - const T *penalty_scores, const T *frequency_scores, - const T *presence_scores, const float *temperatures, - const int64_t *cur_len, const int64_t *min_len, - const int64_t *eos_token_id, const int64_t *bad_words, - const int64_t bs, const int64_t length, - const int64_t length_id, const int64_t end_length, +static int xpu3_wrapper(Context *ctx, + const int64_t *pre_ids, + T *logits, + const T *penalty_scores, + const T *frequency_scores, + const T *presence_scores, + const float *temperatures, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t *bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, const int64_t length_bad_words) { - api::ctx_guard RAII_GUARD(ctx); - using XPU_INT64 = typename XPUIndexType::type; - auto min_length_logits_process_kernel = - xpu3::plugin::min_length_logits_process; - auto update_repeat_times_kernel = xpu3::plugin::update_repeat_times; - auto update_value_by_repeat_times_kernel = - xpu3::plugin::update_value_by_repeat_times; - if (length % 16 == 0) { - update_value_by_repeat_times_kernel = - xpu3::plugin::update_value_by_repeat_times_simd; - } - auto ban_bad_words_kernel = xpu3::plugin::ban_bad_words; + api::ctx_guard RAII_GUARD(ctx); + using XPU_INT64 = typename XPUIndexType::type; + auto min_length_logits_process_kernel = + xpu3::plugin::min_length_logits_process; + auto update_repeat_times_kernel = xpu3::plugin::update_repeat_times; + auto update_value_by_repeat_times_kernel = + xpu3::plugin::update_value_by_repeat_times; + if (length % 16 == 0) { + update_value_by_repeat_times_kernel = + xpu3::plugin::update_value_by_repeat_times_simd; + } + auto ban_bad_words_kernel = xpu3::plugin::ban_bad_words; - int *repeat_times = RAII_GUARD.alloc_l3_or_gm(bs * length); - WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times); - int ret = api::constant(ctx, repeat_times, bs * length, 0); - WRAPPER_ASSERT_SUCCESS(ctx, ret); + int *repeat_times = RAII_GUARD.alloc_l3_or_gm(bs * length); + WRAPPER_ASSERT_WORKSPACE(ctx, repeat_times); + int ret = api::constant(ctx, repeat_times, bs * length, 0); + WRAPPER_ASSERT_SUCCESS(ctx, ret); - update_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - reinterpret_cast(pre_ids), - reinterpret_cast(cur_len), repeat_times, bs, length, - length_id); - min_length_logits_process_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - logits, reinterpret_cast(cur_len), - reinterpret_cast(min_len), - reinterpret_cast(eos_token_id), bs, length, - length_id, end_length); - update_value_by_repeat_times_kernel<<ncluster(), 64, - ctx->xpu_stream>>>( - repeat_times, penalty_scores, frequency_scores, presence_scores, - temperatures, logits, bs, length); + update_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(pre_ids), + reinterpret_cast(cur_len), + repeat_times, + bs, + length, + length_id); + min_length_logits_process_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + logits, + reinterpret_cast(cur_len), + reinterpret_cast(min_len), + reinterpret_cast(eos_token_id), + bs, + length, + length_id, + end_length); + update_value_by_repeat_times_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + repeat_times, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + logits, + bs, + length); - if (bad_words && length_bad_words > 0) { - ban_bad_words_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - logits, reinterpret_cast(bad_words), bs, length, - length_bad_words); - } - return api::SUCCESS; + if (bad_words && length_bad_words > 0) { + ban_bad_words_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + logits, + reinterpret_cast(bad_words), + bs, + length, + length_bad_words); + } + return api::SUCCESS; } template -int token_penalty_multi_scores( - Context *ctx, const int64_t *pre_ids, T *logits, const T *penalty_scores, - const T *frequency_scores, const T *presence_scores, - const float *temperatures, const int64_t *cur_len, const int64_t *min_len, - const int64_t *eos_token_id, const int64_t *bad_words, const int64_t bs, - const int64_t length, const int64_t length_id, const int64_t end_length, - const int64_t length_bad_words) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "token_penalty_multi_scores", T); - WRAPPER_DUMP_PARAM6(ctx, pre_ids, logits, penalty_scores, frequency_scores, - presence_scores, temperatures); - WRAPPER_DUMP_PARAM3(ctx, cur_len, min_len, eos_token_id); - WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length); - WRAPPER_DUMP(ctx); - // TODO(mayang02) shape check - int64_t pre_ids_len = -1; - int64_t logits_len = -1; - int64_t penalty_scores_len = -1; - int64_t frequency_scores_len = -1; - int64_t presence_scores_len = -1; - int64_t temperatures_len = -1; - int64_t cur_len_len = -1; - int64_t min_len_len = -1; - int64_t eos_token_id_len = -1; - int64_t bad_words_len = -1; - WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id}); - WRAPPER_CHECK_SHAPE(ctx, &logits_len, {bs, length}); - WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs}); - WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length}); - WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words}); - WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids); - WRAPPER_CHECK_PTR(ctx, T, logits_len, logits); - WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores); - WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores); - WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores); - WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures); - WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len); - WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len); - WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id); - WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, pre_ids, logits, penalty_scores, - frequency_scores, presence_scores, temperatures, - cur_len, min_len, eos_token_id, bad_words, bs, - length, length_id, end_length, length_bad_words); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, pre_ids, logits, penalty_scores, - frequency_scores, presence_scores, temperatures, - cur_len, min_len, eos_token_id, bad_words, bs, - length, length_id, end_length, length_bad_words); - } - WRAPPER_UNIMPLEMENTED(ctx); +int token_penalty_multi_scores(Context *ctx, + const int64_t *pre_ids, + T *logits, + const T *penalty_scores, + const T *frequency_scores, + const T *presence_scores, + const float *temperatures, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t *bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "token_penalty_multi_scores", T); + WRAPPER_DUMP_PARAM6(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures); + WRAPPER_DUMP_PARAM3(ctx, cur_len, min_len, eos_token_id); + WRAPPER_DUMP_PARAM4(ctx, bs, length, length_id, end_length); + WRAPPER_DUMP(ctx); + // TODO(mayang02) shape check + int64_t pre_ids_len = -1; + int64_t logits_len = -1; + int64_t penalty_scores_len = -1; + int64_t frequency_scores_len = -1; + int64_t presence_scores_len = -1; + int64_t temperatures_len = -1; + int64_t cur_len_len = -1; + int64_t min_len_len = -1; + int64_t eos_token_id_len = -1; + int64_t bad_words_len = -1; + WRAPPER_CHECK_SHAPE(ctx, &pre_ids_len, {bs, length_id}); + WRAPPER_CHECK_SHAPE(ctx, &logits_len, {bs, length}); + WRAPPER_CHECK_SHAPE(ctx, &penalty_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &frequency_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &presence_scores_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &temperatures_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &cur_len_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &min_len_len, {bs}); + WRAPPER_CHECK_SHAPE(ctx, &eos_token_id_len, {end_length}); + WRAPPER_CHECK_SHAPE(ctx, &bad_words_len, {length_bad_words}); + WRAPPER_CHECK_PTR(ctx, int64_t, pre_ids_len, pre_ids); + WRAPPER_CHECK_PTR(ctx, T, logits_len, logits); + WRAPPER_CHECK_PTR(ctx, T, penalty_scores_len, penalty_scores); + WRAPPER_CHECK_PTR(ctx, T, frequency_scores_len, frequency_scores); + WRAPPER_CHECK_PTR(ctx, T, presence_scores_len, presence_scores); + WRAPPER_CHECK_PTR(ctx, float, temperatures_len, temperatures); + WRAPPER_CHECK_PTR(ctx, int64_t, cur_len_len, cur_len); + WRAPPER_CHECK_PTR(ctx, int64_t, min_len_len, min_len); + WRAPPER_CHECK_PTR(ctx, int64_t, eos_token_id_len, eos_token_id); + WRAPPER_CHECK_PTR(ctx, int64_t, bad_words_len, bad_words); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + cur_len, + min_len, + eos_token_id, + bad_words, + bs, + length, + length_id, + end_length, + length_bad_words); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + cur_len, + min_len, + eos_token_id, + bad_words, + bs, + length, + length_id, + end_length, + length_bad_words); + } + WRAPPER_UNIMPLEMENTED(ctx); } -template int token_penalty_multi_scores( - Context *ctx, const int64_t *pre_ids, float *logits, - const float *penalty_scores, const float *frequency_scores, - const float *presence_scores, const float *temperatures, - const int64_t *cur_len, const int64_t *min_len, const int64_t *eos_token_id, - const int64_t *bad_words, const int64_t bs, const int64_t length, - const int64_t length_id, const int64_t end_length, - const int64_t length_bad_words); +template int token_penalty_multi_scores(Context *ctx, + const int64_t *pre_ids, + float *logits, + const float *penalty_scores, + const float *frequency_scores, + const float *presence_scores, + const float *temperatures, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t *bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, + const int64_t length_bad_words); template int token_penalty_multi_scores( - Context *ctx, const int64_t *pre_ids, float16 *logits, - const float16 *penalty_scores, const float16 *frequency_scores, - const float16 *presence_scores, const float *temperatures, - const int64_t *cur_len, const int64_t *min_len, const int64_t *eos_token_id, - const int64_t *bad_words, const int64_t bs, const int64_t length, - const int64_t length_id, const int64_t end_length, + Context *ctx, + const int64_t *pre_ids, + float16 *logits, + const float16 *penalty_scores, + const float16 *frequency_scores, + const float16 *presence_scores, + const float *temperatures, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t *bad_words, + const int64_t bs, + const int64_t length, + const int64_t length_id, + const int64_t end_length, const int64_t length_bad_words); -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/quant2d_per_channel.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/quant2d_per_channel.cpp index 225aff772..00b4347bf 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/quant2d_per_channel.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/quant2d_per_channel.cpp @@ -20,21 +20,18 @@ namespace xpu3 { namespace plugin { template -__attribute__((global)) void -quant2d_per_channel_cluster(const TX *x, const TSCALE *scale, TY *y, int64_t m, - int64_t n); +__attribute__((global)) void quant2d_per_channel_cluster( + const TX *x, const TSCALE *scale, TY *y, int64_t m, int64_t n); template -__attribute__((global)) void -quant2d_per_channel_cached(const TX *input, TY *output, TSCALE *scale, - int64_t m, int64_t n); +__attribute__((global)) void quant2d_per_channel_cached( + const TX *input, TY *output, TSCALE *scale, int64_t m, int64_t n); template -__attribute__((global)) void quant2d_per_channel_bign(const TX *input, - TY *output, TSCALE *scale, - int64_t m, int64_t n); -} // namespace plugin -} // namespace xpu3 +__attribute__((global)) void quant2d_per_channel_bign( + const TX *input, TY *output, TSCALE *scale, int64_t m, int64_t n); +} // namespace plugin +} // namespace xpu3 namespace api = baidu::xpu::api; @@ -43,11 +40,17 @@ namespace xpu { namespace api { namespace plugin { -template ::value, TY>::type *ptr = nullptr> -int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale, - TY *y, int64_t m, int64_t n) { +int cpu_wrapper_input_scale(api::Context *ctx, + const TX *x, + const TSCALE *scale, + TY *y, + int64_t m, + int64_t n) { float absmax = 1e-30f; for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { @@ -78,11 +81,13 @@ static float16 quant_int4(float x, float scale) { return (float16)std::min(static_cast(r), 7.f); } -template ::value, TY>::type *ptr = nullptr> -int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale, - TY *y, int m, int n) { +int cpu_wrapper_input_scale( + api::Context *ctx, const TX *x, const TSCALE *scale, TY *y, int m, int n) { int8_t *y_ptr = reinterpret_cast(y); float t1, t2; for (int i = 0; i < m; ++i) { @@ -109,11 +114,17 @@ int cpu_wrapper_input_scale(api::Context *ctx, const TX *x, const TSCALE *scale, return api::SUCCESS; } -template ::value, TY>::type *ptr = nullptr> -int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, - TY *y, int64_t m, int64_t n) { +int cpu_wrapper_output_scale(api::Context *ctx, + const TX *x, + TSCALE *scale, + TY *y, + int64_t m, + int64_t n) { int64_t i, j; for (j = 0; j < n; ++j) { float absmax = 1e-30f; @@ -129,11 +140,13 @@ int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, return api::SUCCESS; } -template ::value, TY>::type *ptr = nullptr> -int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, - TY *y, int m, int n) { +int cpu_wrapper_output_scale( + api::Context *ctx, const TX *x, TSCALE *scale, TY *y, int m, int n) { int8_t *y_ptr = reinterpret_cast(y); float t1, t2, absmax_1, absmax_2, act_scale_1, act_scale_2; for (int j = 0; j < n; j += 2) { @@ -173,18 +186,28 @@ int cpu_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, } template -int xpu3_wrapper_input_scale(api::Context *ctx, const TX *x, - const TSCALE *scale, TY *y, int64_t m, int64_t n) { +int xpu3_wrapper_input_scale(api::Context *ctx, + const TX *x, + const TSCALE *scale, + TY *y, + int64_t m, + int64_t n) { auto func = xpu3::plugin::quant2d_per_channel_cluster; func<<ncluster(), 64, ctx->xpu_stream>>>(x, scale, y, m, n); return api::SUCCESS; } -template ::value, TY>::type * = nullptr> -int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, - TY *y, int64_t m, int64_t n) { +int xpu3_wrapper_output_scale(api::Context *ctx, + const TX *x, + TSCALE *scale, + TY *y, + int64_t m, + int64_t n) { int64_t channel_size = m * sizeof(TX); int64_t cluster_n = (n + ctx->ncluster() - 1) / ctx->ncluster(); auto func = xpu3::plugin::quant2d_per_channel_bign; @@ -210,19 +233,30 @@ int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, func<<ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n); return api::SUCCESS; } -template ::value, TY>::type * = nullptr> -int xpu3_wrapper_output_scale(api::Context *ctx, const TX *x, TSCALE *scale, - TY *y, int64_t m, int64_t n) { +int xpu3_wrapper_output_scale(api::Context *ctx, + const TX *x, + TSCALE *scale, + TY *y, + int64_t m, + int64_t n) { auto func = xpu3::plugin::quant2d_per_channel_bign; func<<ncluster(), 64, ctx->xpu_stream>>>(x, y, scale, m, n); return api::SUCCESS; } template -int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in, - TY *y, TSCALE *scale_out, int64_t m, int64_t n) { +int quant2d_per_channel(api::Context *ctx, + const TX *x, + const TSCALE *scale_in, + TY *y, + TSCALE *scale_out, + int64_t m, + int64_t n) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T3(ctx, "quant2d_per_channel", TX, TSCALE, TY); WRAPPER_DUMP_PARAM4(ctx, x, scale_in, y, scale_out); @@ -251,20 +285,24 @@ int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in, } if (ctx->dev().type() == api::kXPU3) { if (scale_in != nullptr) { - return xpu3_wrapper_input_scale(ctx, x, scale_in, y, m, - n); + return xpu3_wrapper_input_scale( + ctx, x, scale_in, y, m, n); } - return xpu3_wrapper_output_scale(ctx, x, scale_out, y, m, - n); + return xpu3_wrapper_output_scale( + ctx, x, scale_out, y, m, n); } WRAPPER_UNIMPLEMENTED(ctx); return 0; } -#define INSTANTIATION_QUANT2D_PER_CHANNEL(TX, TSCALE, TY) \ - template int quant2d_per_channel( \ - api::Context *, const TX *, const TSCALE *, TY *, TSCALE *, int64_t, \ - int64_t); +#define INSTANTIATION_QUANT2D_PER_CHANNEL(TX, TSCALE, TY) \ + template int quant2d_per_channel(api::Context *, \ + const TX *, \ + const TSCALE *, \ + TY *, \ + TSCALE *, \ + int64_t, \ + int64_t); INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float, int8_t); INSTANTIATION_QUANT2D_PER_CHANNEL(bfloat16, float, int8_t); @@ -274,7 +312,7 @@ INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float16, int4_t); INSTANTIATION_QUANT2D_PER_CHANNEL(float16, float, int4_t); INSTANTIATION_QUANT2D_PER_CHANNEL(float, float, int4_t); INSTANTIATION_QUANT2D_PER_CHANNEL(bfloat16, float, int4_t); -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_block.cpp index f8292d930..284200b08 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_block.cpp @@ -12,28 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { -__attribute__((global)) void -recover_block(int *recover_block_list, // [bsz] - int *recover_len, bool *stop_flags, int *seq_lens_this_time, - const int *ori_seq_lens_encoder, int *seq_lens_encoder, - const int *seq_lens_decoder, int *block_tables, int *free_list, - int *free_list_len, int64_t *input_ids, const int64_t *pre_ids, - const int64_t *step_idx, const int *encoder_block_lens, - const int *used_list_len, const int64_t *next_tokens, - const int64_t *first_token_ids, const int bsz, - const int block_num_per_seq, const int length, - const int pre_id_length); +__attribute__((global)) void recover_block(int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { @@ -41,125 +51,207 @@ namespace api { namespace plugin { static int cpu_wrapper(Context *ctx, - int *recover_block_list, // [bsz] - int *recover_len, bool *stop_flags, - int *seq_lens_this_time, const int *ori_seq_lens_encoder, - int *seq_lens_encoder, const int *seq_lens_decoder, - int *block_tables, int *free_list, int *free_list_len, - int64_t *input_ids, const int64_t *pre_ids, - const int64_t *step_idx, const int *encoder_block_lens, - const int *used_list_len, const int64_t *next_tokens, - const int64_t *first_token_ids, const int bsz, - const int block_num_per_seq, const int length, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, const int pre_id_length) { - for (int bid = 0; bid < recover_len[0]; bid++) { - const int recover_id = recover_block_list[bid]; - const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; - const int step_idx_now = step_idx[recover_id]; - const int seq_len = ori_seq_len_encoder + step_idx_now; - const int encoder_block_len = encoder_block_lens[recover_id]; - const int decoder_used_len = used_list_len[recover_id]; - int *block_table_now = block_tables + recover_id * block_num_per_seq; - int64_t *input_ids_now = input_ids + recover_id * length; - const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length; + for (int bid = 0; bid < recover_len[0]; bid++) { + const int recover_id = recover_block_list[bid]; + const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + const int step_idx_now = step_idx[recover_id]; + const int seq_len = ori_seq_len_encoder + step_idx_now; + const int encoder_block_len = encoder_block_lens[recover_id]; + const int decoder_used_len = used_list_len[recover_id]; + int *block_table_now = block_tables + recover_id * block_num_per_seq; + int64_t *input_ids_now = input_ids + recover_id * length; + const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length; - seq_lens_this_time[recover_id] = seq_len; - seq_lens_encoder[recover_id] = seq_len; - stop_flags[recover_id] = false; - input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens - input_ids_now[0] = - first_token_ids[recover_id]; // set first prompt token - int ori_free_list_len = free_list_len[0]; - free_list_len[0] -= decoder_used_len; + seq_lens_this_time[recover_id] = seq_len; + seq_lens_encoder[recover_id] = seq_len; + stop_flags[recover_id] = false; + input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens + input_ids_now[0] = first_token_ids[recover_id]; // set first prompt token + int ori_free_list_len = free_list_len[0]; + free_list_len[0] -= decoder_used_len; - // 恢复block table - for (int i = 0; i < decoder_used_len; i++) { - block_table_now[encoder_block_len + i] = - free_list[ori_free_list_len - i - 1]; - } - // 恢复input_ids - for (int i = 0; i < step_idx_now - 1; i++) { - input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1]; - } + // 恢复block table + for (int i = 0; i < decoder_used_len; i++) { + block_table_now[encoder_block_len + i] = + free_list[ori_free_list_len - i - 1]; } - recover_len[0] = 0; - return api::SUCCESS; + // 恢复input_ids + for (int i = 0; i < step_idx_now - 1; i++) { + input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1]; + } + } + recover_len[0] = 0; + return api::SUCCESS; } static int xpu3_wrapper(Context *ctx, - int *recover_block_list, // [bsz] - int *recover_len, bool *stop_flags, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, int *seq_lens_this_time, - const int *ori_seq_lens_encoder, int *seq_lens_encoder, - const int *seq_lens_decoder, int *block_tables, - int *free_list, int *free_list_len, int64_t *input_ids, - const int64_t *pre_ids, const int64_t *step_idx, - const int *encoder_block_lens, const int *used_list_len, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, const int64_t *next_tokens, - const int64_t *first_token_ids, const int bsz, - const int block_num_per_seq, const int length, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, const int pre_id_length) { - using XPU_INT64 = typename XPUIndexType::type; - auto recover_block_kernel = xpu3::plugin::recover_block; - recover_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - recover_block_list, // [bsz] - recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, - seq_lens_encoder, seq_lens_decoder, block_tables, free_list, - free_list_len, reinterpret_cast(input_ids), - reinterpret_cast(pre_ids), - reinterpret_cast(step_idx), encoder_block_lens, - used_list_len, reinterpret_cast(next_tokens), - reinterpret_cast(first_token_ids), bsz, - block_num_per_seq, length, pre_id_length); - return api::SUCCESS; + using XPU_INT64 = typename XPUIndexType::type; + auto recover_block_kernel = xpu3::plugin::recover_block; + recover_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + reinterpret_cast(input_ids), + reinterpret_cast(pre_ids), + reinterpret_cast(step_idx), + encoder_block_lens, + used_list_len, + reinterpret_cast(next_tokens), + reinterpret_cast(first_token_ids), + bsz, + block_num_per_seq, + length, + pre_id_length); + return api::SUCCESS; } int recover_block(Context *ctx, - int *recover_block_list, // [bsz] - int *recover_len, bool *stop_flags, int *seq_lens_this_time, - const int *ori_seq_lens_encoder, int *seq_lens_encoder, - const int *seq_lens_decoder, int *block_tables, - int *free_list, int *free_list_len, int64_t *input_ids, - const int64_t *pre_ids, const int64_t *step_idx, - const int *encoder_block_lens, const int *used_list_len, - const int64_t *next_tokens, const int64_t *first_token_ids, - const int bsz, const int block_num_per_seq, const int length, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, const int pre_id_length) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_block", float); - WRAPPER_DUMP_PARAM6(ctx, recover_block_list, recover_len, stop_flags, - seq_lens_this_time, ori_seq_lens_encoder, - seq_lens_encoder); - WRAPPER_DUMP_PARAM6(ctx, seq_lens_decoder, block_tables, free_list, - free_list_len, input_ids, pre_ids); - WRAPPER_DUMP_PARAM5(ctx, step_idx, encoder_block_lens, used_list_len, - next_tokens, first_token_ids); - WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper( - ctx, - recover_block_list, // [bsz] - recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, - seq_lens_encoder, seq_lens_decoder, block_tables, free_list, - free_list_len, input_ids, pre_ids, step_idx, encoder_block_lens, - used_list_len, next_tokens, first_token_ids, bsz, block_num_per_seq, - length, pre_id_length); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper( - ctx, - recover_block_list, // [bsz] - recover_len, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, - seq_lens_encoder, seq_lens_decoder, block_tables, free_list, - free_list_len, input_ids, pre_ids, step_idx, encoder_block_lens, - used_list_len, next_tokens, first_token_ids, bsz, block_num_per_seq, - length, pre_id_length); - } - WRAPPER_UNIMPLEMENTED(ctx); + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_block", float); + WRAPPER_DUMP_PARAM6(ctx, + recover_block_list, + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder); + WRAPPER_DUMP_PARAM6(ctx, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids); + WRAPPER_DUMP_PARAM5(ctx, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids); + WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids, + bsz, + block_num_per_seq, + length, + pre_id_length); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids, + bsz, + block_num_per_seq, + length, + pre_id_length); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp index 1ed700897..8c6217dc7 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/recover_decode_task.cpp @@ -12,96 +12,102 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { -__attribute__((global)) void -recover_decode_task(bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, - const int bsz, - const int block_num_per_seq, - const int block_size); +__attribute__((global)) void recover_decode_task(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int xpu3_wrapper(Context *ctx, bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, - const int bsz, - const int block_num_per_seq, - const int block_size) { - using XPU_INT64 = typename XPUIndexType::type; - auto recover_decode_task = xpu3::plugin::recover_decode_task; - recover_decode_task<<ncluster(), 64, ctx->xpu_stream>>>( - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_seq_lens_decoder, - block_tables, - is_block_step, - bsz, - block_num_per_seq, - block_size); - return api::SUCCESS; +static int xpu3_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { + using XPU_INT64 = typename XPUIndexType::type; + auto recover_decode_task = xpu3::plugin::recover_decode_task; + recover_decode_task<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + block_tables, + is_block_step, + bsz, + block_num_per_seq, + block_size); + return api::SUCCESS; } -int recover_decode_task(Context *ctx, bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, - const int bsz, - const int block_num_per_seq, - const int block_size) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int); - WRAPPER_DUMP_PARAM5(ctx, stop_flags, seq_lens_this_time, - seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder); - WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step); - WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - assert(false); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_seq_lens_decoder, - block_tables, - is_block_step, - bsz, - block_num_per_seq, - block_size); - } - WRAPPER_UNIMPLEMENTED(ctx); +int recover_decode_task(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int); + WRAPPER_DUMP_PARAM5(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder); + WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step); + WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + assert(false); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + block_tables, + is_block_step, + bsz, + block_num_per_seq, + block_size); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp index d4c52293c..f719ed9fe 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp @@ -18,18 +18,17 @@ namespace xpu3 { namespace plugin { template -__attribute__((global)) void text_image_gather_scatter( - T* input, - T* text_input, - T* image_input, - int* token_type_ids, - int* text_index, - int* image_index, - int64_t token_num, - int64_t text_token_num, - int64_t image_token_num, - int64_t hidden_size, - bool is_scatter); +__attribute__((global)) void text_image_gather_scatter(T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter); } // namespace plugin } // namespace xpu3 @@ -41,18 +40,17 @@ namespace plugin { template static int cpu_wrapper( Context* ctx, - T* input, // shape [token_num, hidden_size] - T* text_input, // shape [text_token_num, hidden_size] - T* image_input, // shape [image_token_num, hidden_size] - int* token_type_ids,// shape [token_num], 0 for text, 1 for image - int* text_index, // shape [token_num], mapping from input to text_input - int* image_index, // shape [token_num], mapping from input to image_input + T* input, // shape [token_num, hidden_size] + T* text_input, // shape [text_token_num, hidden_size] + T* image_input, // shape [image_token_num, hidden_size] + int* token_type_ids, // shape [token_num], 0 for text, 1 for image + int* text_index, // shape [token_num], mapping from input to text_input + int* image_index, // shape [token_num], mapping from input to image_input int64_t token_num, int64_t text_token_num, int64_t image_token_num, int64_t hidden_size, bool is_scatter) { - if (is_scatter) { // Scatter mode: input -> text_input/image_input for (int64_t i = 0; i < token_num; i++) { @@ -106,36 +104,42 @@ static int cpu_wrapper( } template -static int xpu3_wrapper( - Context* ctx, - T* input, - T* text_input, - T* image_input, - int* token_type_ids, - int* text_index, - int* image_index, - int64_t token_num, - int64_t text_token_num, - int64_t image_token_num, - int64_t hidden_size, - bool is_scatter) { - xpu3::plugin::text_image_gather_scatter <<ncluster(), 64, ctx->xpu_stream>>>( - input, text_input, image_input, token_type_ids, text_index, image_index, - token_num, text_token_num, image_token_num, hidden_size, is_scatter - ); +static int xpu3_wrapper(Context* ctx, + T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter) { + xpu3::plugin::text_image_gather_scatter + <<ncluster(), 64, ctx->xpu_stream>>>(input, + text_input, + image_input, + token_type_ids, + text_index, + image_index, + token_num, + text_token_num, + image_token_num, + hidden_size, + is_scatter); return api::SUCCESS; } - template int text_image_gather_scatter( Context* ctx, - T* input, // shape [token_num, hidden_size] - T* text_input, // shape [text_token_num, hidden_size] - T* image_input, // shape [image_token_num, hidden_size] - int* token_type_ids,// shape [token_num], 0 for text, 1 for image - int* text_index, // shape [token_num], mapping from input to text_input - int* image_index, // shape [token_num], mapping from input to image_input + T* input, // shape [token_num, hidden_size] + T* text_input, // shape [text_token_num, hidden_size] + T* image_input, // shape [image_token_num, hidden_size] + int* token_type_ids, // shape [token_num], 0 for text, 1 for image + int* text_index, // shape [token_num], mapping from input to text_input + int* image_index, // shape [token_num], mapping from input to image_input int64_t token_num, int64_t text_token_num, int64_t image_token_num, @@ -143,14 +147,23 @@ int text_image_gather_scatter( bool is_scatter) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T); - WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index); - WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter); + WRAPPER_DUMP_PARAM6(ctx, + input, + text_input, + image_input, + token_type_ids, + text_index, + image_index); + WRAPPER_DUMP_PARAM5( + ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter); WRAPPER_DUMP(ctx); WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input); - if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size] + if (text_token_num != + 0) { // avoiding text_input tensor with shape [0, hidden_size] WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input); } - if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size] + if (image_token_num != + 0) { // avoiding image_input tensor with shape [0, hidden_size] WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input); } WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids); @@ -159,23 +172,48 @@ int text_image_gather_scatter( WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num); if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper( - ctx, input, text_input, image_input, token_type_ids, text_index, image_index, - token_num, text_token_num, image_token_num, hidden_size, is_scatter - ); + return cpu_wrapper(ctx, + input, + text_input, + image_input, + token_type_ids, + text_index, + image_index, + token_num, + text_token_num, + image_token_num, + hidden_size, + is_scatter); } if (ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper( - ctx, input, text_input, image_input, token_type_ids, text_index, image_index, - token_num, text_token_num, image_token_num, hidden_size, is_scatter - ); + return xpu3_wrapper(ctx, + input, + text_input, + image_input, + token_type_ids, + text_index, + image_index, + token_num, + text_token_num, + image_token_num, + hidden_size, + is_scatter); } WRAPPER_UNIMPLEMENTED(ctx); } - -template int text_image_gather_scatter( - Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool); +template int text_image_gather_scatter(Context*, + bfloat16*, + bfloat16*, + bfloat16*, + int*, + int*, + int*, + const int64_t, + const int64_t, + const int64_t, + const int64_t, + bool); } // namespace plugin } // namespace api } // namespace xpu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp index 3a2cd44c4..0adab1dd9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp @@ -17,10 +17,11 @@ namespace xpu3 { namespace plugin { -__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x - int* text_index, // y1 - int* image_index, // y2 - const int64_t token_num); +__attribute__((global)) void text_image_index_out_kernel( + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num); } // namespace plugin } // namespace xpu3 @@ -30,69 +31,54 @@ namespace api { namespace plugin { static int cpu_wrapper(Context* ctx, - const int* token_type_ids, // x - int* text_index, // y1 - int* image_index, // y2 + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 const int64_t token_num) { - int text_count = 0; + int text_count = 0; int image_count = 0; for (int64_t i = 0; i < token_num; ++i) { - if (token_type_ids[i] == 0) { - text_index[i] = text_count; - ++text_count; - } else { - image_index[i] = image_count; - ++image_count; - } + if (token_type_ids[i] == 0) { + text_index[i] = text_count; + ++text_count; + } else { + image_index[i] = image_count; + ++image_count; + } } return api::SUCCESS; - } static int xpu3_wrapper(Context* ctx, - const int* token_type_ids, // x - int* text_index, // y1 - int* image_index, // y2 + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 const int64_t token_num) { - xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>( - token_type_ids, - text_index, - image_index, - token_num); + token_type_ids, text_index, image_index, token_num); return api::SUCCESS; } int text_image_index_out(Context* ctx, - const int* token_type_ids, // x - int* text_index, // y1 - int* image_index, // y2 + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 const int64_t token_num) { - WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int); - WRAPPER_DUMP_PARAM4( - ctx, token_type_ids, text_index, image_index, token_num); + WRAPPER_DUMP_PARAM4(ctx, token_type_ids, text_index, image_index, token_num); WRAPPER_DUMP(ctx); WRAPPER_ASSERT_GT(ctx, token_num, 0); WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids); WRAPPER_CHECK_PTR(ctx, int, token_num, text_index); WRAPPER_CHECK_PTR(ctx, int, token_num, image_index); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, - token_type_ids, - text_index, - image_index, - token_num); + return cpu_wrapper(ctx, token_type_ids, text_index, image_index, token_num); } else if (ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, - token_type_ids, - text_index, - image_index, - token_num); + return xpu3_wrapper( + ctx, token_type_ids, text_index, image_index, token_num); } WRAPPER_UNIMPLEMENTED(ctx); } diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs.cpp index a12edfd1d..248d11349 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs.cpp @@ -12,108 +12,162 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { -__attribute__((global)) void -update_inputs(bool *not_need_stop, int *seq_lens_this_time, - int *seq_lens_encoder, int *seq_lens_decoder, int64_t *input_ids, - const int64_t *stop_nums, const bool *stop_flags, - const bool *is_block_step, const int64_t *next_tokens, - const int bsz, const int max_bsz, const int input_ids_stride); +__attribute__((global)) void update_inputs(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int cpu_wrapper(Context *ctx, bool *not_need_stop, - int *seq_lens_this_time, int *seq_lens_encoder, - int *seq_lens_decoder, int64_t *input_ids, - const int64_t *stop_nums, const bool *stop_flags, - const bool *is_block_step, const int64_t *next_tokens, - const int bsz, const int max_bsz, +static int cpu_wrapper(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, const int input_ids_stride) { - std::vector stop_flag_now_int(max_bsz, 1); - for (int i = 0; i < bsz; i++) { - bool stop_flags_now = stop_flags[i]; - stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now; - const int seq_len_encoder = seq_lens_encoder[i]; - const int seq_len_decoder = seq_lens_decoder[i]; + std::vector stop_flag_now_int(max_bsz, 1); + for (int i = 0; i < bsz; i++) { + bool stop_flags_now = stop_flags[i]; + stop_flag_now_int[i] = is_block_step[i] ? 0 : stop_flags_now; + const int seq_len_encoder = seq_lens_encoder[i]; + const int seq_len_decoder = seq_lens_decoder[i]; - seq_lens_decoder[i] = - stop_flags[i] ? 0 - : (seq_len_decoder == 0 ? seq_len_encoder - : seq_len_decoder + 1); + seq_lens_decoder[i] = + stop_flags[i] + ? 0 + : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1); - seq_lens_this_time[i] = stop_flags[i] ? 0 : 1; - seq_lens_encoder[i] = 0; - int64_t *input_ids_now = input_ids + i * input_ids_stride; - input_ids_now[0] = next_tokens[i]; - } - int64_t stop_sum = 0; - for (size_t i = 0; i < stop_flag_now_int.size(); i++) { - stop_sum += stop_flag_now_int[i]; - } - not_need_stop[0] = stop_sum < stop_nums[0]; - return api::SUCCESS; + seq_lens_this_time[i] = stop_flags[i] ? 0 : 1; + seq_lens_encoder[i] = 0; + int64_t *input_ids_now = input_ids + i * input_ids_stride; + input_ids_now[0] = next_tokens[i]; + } + int64_t stop_sum = 0; + for (size_t i = 0; i < stop_flag_now_int.size(); i++) { + stop_sum += stop_flag_now_int[i]; + } + not_need_stop[0] = stop_sum < stop_nums[0]; + return api::SUCCESS; } -static int xpu3_wrapper(Context *ctx, bool *not_need_stop, - int *seq_lens_this_time, int *seq_lens_encoder, - int *seq_lens_decoder, int64_t *input_ids, - const int64_t *stop_nums, const bool *stop_flags, - const bool *is_block_step, const int64_t *next_tokens, - const int bsz, const int max_bsz, +static int xpu3_wrapper(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, const int input_ids_stride) { - using XPU_INT64 = typename XPUIndexType::type; - auto update_inputs = xpu3::plugin::update_inputs; - update_inputs<<ncluster(), 64, ctx->xpu_stream>>>( - not_need_stop, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, - reinterpret_cast(input_ids), - reinterpret_cast(stop_nums), stop_flags, - is_block_step, reinterpret_cast(next_tokens), bsz, - max_bsz, input_ids_stride); - return api::SUCCESS; + using XPU_INT64 = typename XPUIndexType::type; + auto update_inputs = xpu3::plugin::update_inputs; + update_inputs<<ncluster(), 64, ctx->xpu_stream>>>( + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(input_ids), + reinterpret_cast(stop_nums), + stop_flags, + is_block_step, + reinterpret_cast(next_tokens), + bsz, + max_bsz, + input_ids_stride); + return api::SUCCESS; } -int update_inputs(Context *ctx, bool *not_need_stop, int *seq_lens_this_time, - int *seq_lens_encoder, int *seq_lens_decoder, - int64_t *input_ids, const int64_t *stop_nums, - const bool *stop_flags, const bool *is_block_step, - const int64_t *next_tokens, const int bsz, const int max_bsz, +int update_inputs(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, const int input_ids_stride) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int); - WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time, - seq_lens_encoder, seq_lens_decoder, input_ids); - WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens); - WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, not_need_stop, seq_lens_this_time, - seq_lens_encoder, seq_lens_decoder, input_ids, - stop_nums, stop_flags, is_block_step, next_tokens, - bsz, max_bsz, input_ids_stride); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, not_need_stop, seq_lens_this_time, - seq_lens_encoder, seq_lens_decoder, input_ids, - stop_nums, stop_flags, is_block_step, next_tokens, - bsz, max_bsz, input_ids_stride); - } - WRAPPER_UNIMPLEMENTED(ctx); + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs", int); + WRAPPER_DUMP_PARAM5(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids); + WRAPPER_DUMP_PARAM4(ctx, stop_nums, stop_flags, is_block_step, next_tokens); + WRAPPER_DUMP_PARAM3(ctx, bsz, max_bsz, input_ids_stride); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids, + stop_nums, + stop_flags, + is_block_step, + next_tokens, + bsz, + max_bsz, + input_ids_stride); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + input_ids, + stop_nums, + stop_flags, + is_block_step, + next_tokens, + bsz, + max_bsz, + input_ids_stride); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp index ce97e91d7..7fe1772c4 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp @@ -12,138 +12,146 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { -__attribute__((global)) void -update_inputs_v1(bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, - const int bsz, - const int max_bsz, - const int input_ids_stride, - const int block_num_per_seq, - const int block_size); +__attribute__((global)) void update_inputs_v1(bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int xpu3_wrapper(Context *ctx, bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, - const int bsz, - const int max_bsz, - const int input_ids_stride, - const int block_num_per_seq, - const int block_size) { - using XPU_INT64 = typename XPUIndexType::type; - auto update_inputs_v1 = xpu3::plugin::update_inputs_v1; - // kernel 内要做 reduce,只能用 1 个 cluster - update_inputs_v1<<<1, 64, ctx->xpu_stream>>>( - not_need_stop, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_seq_lens_decoder, - reinterpret_cast(prompt_lens), - reinterpret_cast(topk_ids), - reinterpret_cast(input_ids), - block_tables, - reinterpret_cast(stop_nums), - stop_flags, - is_block_step, - reinterpret_cast(next_tokens), - bsz, - max_bsz, - input_ids_stride, - block_num_per_seq, - block_size); - return api::SUCCESS; +static int xpu3_wrapper(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size) { + using XPU_INT64 = typename XPUIndexType::type; + auto update_inputs_v1 = xpu3::plugin::update_inputs_v1; + // kernel 内要做 reduce,只能用 1 个 cluster + update_inputs_v1<<<1, 64, ctx->xpu_stream>>>( + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + reinterpret_cast(prompt_lens), + reinterpret_cast(topk_ids), + reinterpret_cast(input_ids), + block_tables, + reinterpret_cast(stop_nums), + stop_flags, + is_block_step, + reinterpret_cast(next_tokens), + bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size); + return api::SUCCESS; } -int update_inputs_v1(Context *ctx, bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, - const int bsz, - const int max_bsz, - const int input_ids_stride, - const int block_num_per_seq, - const int block_size) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int); - WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time, - seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder); - WRAPPER_DUMP_PARAM5(ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums); - WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens); - WRAPPER_DUMP_PARAM5(ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - assert(false); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, not_need_stop, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_seq_lens_decoder, - prompt_lens, - topk_ids, - input_ids, - block_tables, - stop_nums, - stop_flags, - is_block_step, - next_tokens, - bsz, - max_bsz, - input_ids_stride, - block_num_per_seq, - block_size); - } - WRAPPER_UNIMPLEMENTED(ctx); +int update_inputs_v1(Context *ctx, + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *prompt_lens, + int64_t *topk_ids, + int64_t *input_ids, + int *block_tables, + const int64_t *stop_nums, + bool *stop_flags, + bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride, + const int block_num_per_seq, + const int block_size) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int); + WRAPPER_DUMP_PARAM5(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder); + WRAPPER_DUMP_PARAM5( + ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums); + WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens); + WRAPPER_DUMP_PARAM5( + ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + assert(false); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + not_need_stop, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_seq_lens_decoder, + prompt_lens, + topk_ids, + input_ids, + block_tables, + stop_nums, + stop_flags, + is_block_step, + next_tokens, + bsz, + max_bsz, + input_ids_stride, + block_num_per_seq, + block_size); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h index 596e3b2e6..1e94e0824 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h @@ -22,32 +22,25 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include #include -#include +#include #include +#include +#include +#include +#include +#include +#include #include +#include #include -#include #include +#include +#include +#include +#include #include "kvcache_rdma.h" #include "util.h" @@ -60,115 +53,115 @@ /// @brief IB device information structure struct IbDeviceInfo { - int device; - uint64_t guid; - enum ibv_mtu mtu; - uint64_t busid; - uint8_t port; - uint8_t link; - uint8_t active_mtu; - int speed; - ibv_context* context; - char devName[64]; - int realPort; - int maxQp; + int device; + uint64_t guid; + enum ibv_mtu mtu; + uint64_t busid; + uint8_t port; + uint8_t link; + uint8_t active_mtu; + int speed; + ibv_context* context; + char devName[64]; + int realPort; + int maxQp; }; /// @brief Queue Pair information for RDMA struct QpInfo { - uint32_t lid; - uint32_t qpn; - uint32_t psn; - union ibv_gid gid; - enum ibv_mtu mtu; + uint32_t lid; + uint32_t qpn; + uint32_t psn; + union ibv_gid gid; + enum ibv_mtu mtu; - /// @brief Serialize QP info to buffer - void serialize(char* buffer) const { - uint32_t* intBuffer = reinterpret_cast(buffer); - intBuffer[0] = htonl(lid); - intBuffer[1] = htonl(qpn); - intBuffer[2] = htonl(psn); - memcpy(buffer + 12, gid.raw, sizeof(gid.raw)); - intBuffer[7] = htonl(static_cast(mtu)); - } + /// @brief Serialize QP info to buffer + void serialize(char* buffer) const { + uint32_t* intBuffer = reinterpret_cast(buffer); + intBuffer[0] = htonl(lid); + intBuffer[1] = htonl(qpn); + intBuffer[2] = htonl(psn); + memcpy(buffer + 12, gid.raw, sizeof(gid.raw)); + intBuffer[7] = htonl(static_cast(mtu)); + } - /// @brief Deserialize QP info from buffer - void deserialize(const char* buffer) { - const uint32_t* intBuffer = reinterpret_cast(buffer); - lid = ntohl(intBuffer[0]); - qpn = ntohl(intBuffer[1]); - psn = ntohl(intBuffer[2]); - memcpy(gid.raw, buffer + 12, sizeof(gid.raw)); - mtu = static_cast(ntohl(intBuffer[7])); - } + /// @brief Deserialize QP info from buffer + void deserialize(const char* buffer) { + const uint32_t* intBuffer = reinterpret_cast(buffer); + lid = ntohl(intBuffer[0]); + qpn = ntohl(intBuffer[1]); + psn = ntohl(intBuffer[2]); + memcpy(gid.raw, buffer + 12, sizeof(gid.raw)); + mtu = static_cast(ntohl(intBuffer[7])); + } - static const size_t size = 12 + sizeof(gid.raw) + 4; + static const size_t size = 12 + sizeof(gid.raw) + 4; }; /// @brief RDMA connection context struct Connection { - std::atomic connected; + std::atomic connected; - // Memory regions - struct ibv_mr *recv_mr; - struct ibv_mr *send_mr; + // Memory regions + struct ibv_mr* recv_mr; + struct ibv_mr* send_mr; - // Cache pointers - std::vector> local_cache_key_ptr_per_layer; - std::vector> local_cache_value_ptr_per_layer; + // Cache pointers + std::vector> local_cache_key_ptr_per_layer; + std::vector> local_cache_value_ptr_per_layer; - // Memory region lists - std::vector write_cache_key_server_mr_list; - std::vector write_cache_value_server_mr_list; - std::vector> write_mr_key_list; - std::vector> write_mr_value_list; + // Memory region lists + std::vector write_cache_key_server_mr_list; + std::vector write_cache_value_server_mr_list; + std::vector> write_mr_key_list; + std::vector> write_mr_value_list; - // Remote access information - std::vector write_cache_key_remote_ptr_list; - std::vector write_cache_key_remote_rkey_list; - std::vector write_cache_value_remote_ptr_list; - std::vector write_cache_value_remote_rkey_list; + // Remote access information + std::vector write_cache_key_remote_ptr_list; + std::vector write_cache_key_remote_rkey_list; + std::vector write_cache_value_remote_ptr_list; + std::vector write_cache_value_remote_rkey_list; - // Received remote memory information - std::vector receive_write_cache_key_remote_ptr_list; - std::vector receive_write_cache_key_remote_rkey_list; - std::vector receive_write_cache_value_remote_ptr_list; - std::vector receive_write_cache_value_remote_rkey_list; + // Received remote memory information + std::vector receive_write_cache_key_remote_ptr_list; + std::vector receive_write_cache_key_remote_rkey_list; + std::vector receive_write_cache_value_remote_ptr_list; + std::vector receive_write_cache_value_remote_rkey_list; - std::vector send_write_cache_key_remote_ptr_list; - std::vector send_write_cache_key_remote_rkey_list; - std::vector send_write_cache_value_remote_ptr_list; - std::vector send_write_cache_value_remote_rkey_list; + std::vector send_write_cache_key_remote_ptr_list; + std::vector send_write_cache_key_remote_rkey_list; + std::vector send_write_cache_value_remote_ptr_list; + std::vector send_write_cache_value_remote_rkey_list; - // For rdma read operations - std::vector read_bufs; - std::vector read_mrs; + // For rdma read operations + std::vector read_bufs; + std::vector read_mrs; - // Work completion tracking - int wc_count; - int wc_target_count; + // Work completion tracking + int wc_count; + int wc_target_count; - // Configuration - int layer_number; - int block_number; - int block_byte_size; - std::string url; + // Configuration + int layer_number; + int block_number; + int block_byte_size; + std::string url; - Connection() = default; - ~Connection(); + Connection() = default; + ~Connection(); }; /// @brief RDMA context structure struct RdmaContext { - int sock_fd; - struct ibv_context* context; - struct ibv_comp_channel* channel; - struct ibv_pd* pd; - struct ibv_mr* mr; - struct ibv_cq* cq; - struct ibv_qp* qp; - struct ibv_port_attr portinfo; - struct Connection conn; + int sock_fd; + struct ibv_context* context; + struct ibv_comp_channel* channel; + struct ibv_pd* pd; + struct ibv_mr* mr; + struct ibv_cq* cq; + struct ibv_qp* qp; + struct ibv_port_attr portinfo; + struct Connection conn; }; // Global variables @@ -176,36 +169,46 @@ extern std::vector g_ib_all_devs; static int g_kvcache_ib_dev_nums = -1; // Connection management functions -bool client_exchange_destinations( - struct RdmaContext* ctx, - int ib_port, - unsigned int port, - int gidx, - const std::string& dst_ip); +bool client_exchange_destinations(struct RdmaContext* ctx, + int ib_port, + unsigned int port, + int gidx, + const std::string& dst_ip); int server_exchange_qp_info(int connfd, QpInfo* local_dest, QpInfo* rem_dest); -struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd); +struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, + struct ibv_pd** g_pd); bool clear_qp_info(struct RdmaContext* ctx); // QP modification functions -QpStatus modify_qp_to_rts(struct RdmaContext* ctx, int port, int my_psn, - struct QpInfo* dest, int sgid_id); -bool poll_cq_with_timeout(struct RdmaContext* ctx, int timeout_seconds, int cqe_count); +QpStatus modify_qp_to_rts(struct RdmaContext* ctx, + int port, + int my_psn, + struct QpInfo* dest, + int sgid_id); +bool poll_cq_with_timeout(struct RdmaContext* ctx, + int timeout_seconds, + int cqe_count); // Utility functions -int get_port_info(struct ibv_context* Context, int port, - struct ibv_port_attr* attr); +int get_port_info(struct ibv_context* Context, + int port, + struct ibv_port_attr* attr); int parse_port_ib_info(); // Memory region exchange bool client_exchange_mr(struct RdmaContext* ctx); bool server_exchange_mr(struct RdmaContext* ctx); -bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte_num); -bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int byte_num); +bool server_send_memory_region(struct RdmaContext* ctx, + void* local_mr, + int byte_num); +bool client_receive_memory_region(struct RdmaContext* ctx, + void* remote_mr, + int byte_num); // Network setup int setup_listening_socket(int port); int configure_epoll(int sockfd); std::vector get_net_ifname(); -#endif // FASTDEPLOY_KVCACHE_CONNECTION_H +#endif // FASTDEPLOY_KVCACHE_CONNECTION_H diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h index de759e909..e0251f8d4 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h @@ -4,77 +4,88 @@ #pragma once #include -#include -#include #include #include -#include "util.h" // Contains constant definitions +#include +#include #include "kvcache_connection.h" #include "log.h" - +#include "util.h" // Contains constant definitions /** * @brief RDMA communication handler for key-value cache */ class RDMACommunicator { -public: - // Construction/Destruction - RDMACommunicator(std::string &role, int gpu_idx, std::string &port, - std::vector local_key_cache, - std::vector local_value_cache, - int block_number, int block_bytes); - ~RDMACommunicator(); + public: + // Construction/Destruction + RDMACommunicator(std::string& role, + int gpu_idx, + std::string& port, + std::vector local_key_cache, + std::vector local_value_cache, + int block_number, + int block_bytes); + ~RDMACommunicator(); - // Connection management - int connect(const std::string &dst_ip, const std::string &dst_port); - bool is_connected(const std::string &dst_ip, const std::string &dst_port); + // Connection management + int connect(const std::string& dst_ip, const std::string& dst_port); + bool is_connected(const std::string& dst_ip, const std::string& dst_port); - // Core functionality - int write_cache(const std::string &ip, const std::string &port, - const std::vector& local_block_ids, - const std::vector& remote_block_ids, - int32_t layer_idx); + // Core functionality + int write_cache(const std::string& ip, + const std::string& port, + const std::vector& local_block_ids, + const std::vector& remote_block_ids, + int32_t layer_idx); - // Server Init - int init_server(); + // Server Init + int init_server(); - // get socket nic ip - std::string fetch_local_ip(); + // get socket nic ip + std::string fetch_local_ip(); -private: - // Server Core functions - int start_server(int sport, int sgid_idx, int gpu_index); + private: + // Server Core functions + int start_server(int sport, int sgid_idx, int gpu_index); - // Internal implementation methods - void resize_vectors(); - void assign_pointers(); - void validate_addr(); - bool client_mr_register_per_layer(struct RdmaContext *ctx); - bool server_mr_register_per_layer(struct RdmaContext *ctx); - struct ibv_mr* register_memory_region(ibv_pd* pd, void* addr, size_t size, - const std::string& desc, uint32_t access_flags); - bool deregister_memory_regions(struct RdmaContext* ctx); + // Internal implementation methods + void resize_vectors(); + void assign_pointers(); + void validate_addr(); + bool client_mr_register_per_layer(struct RdmaContext* ctx); + bool server_mr_register_per_layer(struct RdmaContext* ctx); + struct ibv_mr* register_memory_region(ibv_pd* pd, + void* addr, + size_t size, + const std::string& desc, + uint32_t access_flags); + bool deregister_memory_regions(struct RdmaContext* ctx); - bool post_block_send(struct RdmaContext* ctx, int layer_idx, - const std::vector& local_block_ids, - bool is_key, std::vector& remote_addr, - uint32_t rkey, const std::string &ip, - const std::string &port); + bool post_block_send(struct RdmaContext* ctx, + int layer_idx, + const std::vector& local_block_ids, + bool is_key, + std::vector& remote_addr, + uint32_t rkey, + const std::string& ip, + const std::string& port); - bool execute_rdma_writes(struct RdmaContext* ctx, int layer_idx, + bool execute_rdma_writes(struct RdmaContext* ctx, + int layer_idx, const std::vector& local_block_ids, - bool is_key, std::vector& remote_addr, + bool is_key, + std::vector& remote_addr, uint32_t rkey); - void prepare_write_requests(struct ibv_sge* sge_list, - struct ibv_send_wr* send_wr_list, - int layer_idx, - const std::vector& local_block_ids, - bool is_key, - std::vector& remote_addr, - uint32_t rkey); + void prepare_write_requests(struct ibv_sge* sge_list, + struct ibv_send_wr* send_wr_list, + int layer_idx, + const std::vector& local_block_ids, + bool is_key, + std::vector& remote_addr, + uint32_t rkey); - bool execute_read_verification(struct RdmaContext* ctx, + bool execute_read_verification(struct RdmaContext* ctx, size_t block_idx, uint64_t remote_addr, uint32_t rkey, @@ -82,46 +93,56 @@ private: const std::string& ip, const std::string& port); - bool post_send_with_retry(struct RdmaContext* ctx, + bool post_send_with_retry(struct RdmaContext* ctx, struct ibv_send_wr* wr_list, size_t inflight_wr, bool need_poll); - // Connection management - int client_listener(); - void close_server_connection(int fd, struct RdmaContext* ctx, int epollfd, - std::map& connectionContexts); - void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd); + // Connection management + int client_listener(); + void close_server_connection( + int fd, + struct RdmaContext* ctx, + int epollfd, + std::map& connectionContexts); + void close_client_connection(int fd, struct RdmaContext* ctx, int epollfd); - void remove_conn(const std::string& url); - struct RdmaContext *get_conn(const std::string &ip, - const std::string &port); + void remove_conn(const std::string& url); + struct RdmaContext* get_conn(const std::string& ip, const std::string& port); - // Member variables - std::string splitwise_role; // Role in distributed system ("decode" or other) - int gpu_idx; // GPU device index - std::string port; // Communication port - std::vector local_cache_key_ptr_layer_head_; // Key cache pointers - std::vector local_cache_value_ptr_layer_head_; // Value cache pointers - int block_number; // Number of blocks - int block_size_byte; // Size of each block in bytes - int layer_number; // Number of layers + // Member variables + std::string splitwise_role; // Role in distributed system ("decode" or other) + int gpu_idx; // GPU device index + std::string port; // Communication port + std::vector local_cache_key_ptr_layer_head_; // Key cache pointers + std::vector + local_cache_value_ptr_layer_head_; // Value cache pointers + int block_number; // Number of blocks + int block_size_byte; // Size of each block in bytes + int layer_number; // Number of layers - std::vector> local_cache_key_ptr_per_layer; // Per-layer key pointers - std::vector> local_cache_value_ptr_per_layer; // Per-layer value pointers + std::vector> + local_cache_key_ptr_per_layer; // Per-layer key pointers + std::vector> + local_cache_value_ptr_per_layer; // Per-layer value pointers - std::vector write_mr_key_list; // Memory regions for key writes - std::vector write_mr_value_list; // Memory regions for value writes - std::vector write_cache_key_server_mr_list; // Server-side key memory regions - std::vector write_cache_value_server_mr_list; // Server-side value memory regions + std::vector + write_mr_key_list; // Memory regions for key writes + std::vector + write_mr_value_list; // Memory regions for value writes + std::vector + write_cache_key_server_mr_list; // Server-side key memory regions + std::vector + write_cache_value_server_mr_list; // Server-side value memory regions - std::vector main_ip_list; // List of local IP addresses - std::map conn_map; // Active connections map - std::mutex mutex_; // Thread synchronization mutex - int rdma_event_channel_epoll_fd; // Epoll file descriptor - struct ibv_pd *g_pd = NULL; // fd - int RDMACommunicator_status; // Communicator status flag - bool start_client_listener = false; // Client listener flag + std::vector main_ip_list; // List of local IP addresses + std::map + conn_map; // Active connections map + std::mutex mutex_; // Thread synchronization mutex + int rdma_event_channel_epoll_fd; // Epoll file descriptor + struct ibv_pd* g_pd = NULL; // fd + int RDMACommunicator_status; // Communicator status flag + bool start_client_listener = false; // Client listener flag }; -#endif // KVCACHE_RDMA_H +#endif // KVCACHE_RDMA_H diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/log.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/log.h index d0bf18ae2..e68d9ce58 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/log.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/log.h @@ -19,99 +19,130 @@ * limitations under the License. */ +#include #include #include -#include -#include -#include //for gethostname #include -#include -#include -#include +#include +#include +#include //for gethostname #include +#include +#include #define KV_IS_DEBUG_ENABLED (std::getenv("KVCACHE_DEBUG")) -#define FILE_NAME(x) (strrchr(x,'/') ? strrchr(x,'/')+1 : x) +#define FILE_NAME(x) (strrchr(x, '/') ? strrchr(x, '/') + 1 : x) static thread_local char __attribute__((__unused__)) str[64]; // for log levels (C++ enum class style in C) typedef enum { - KV_LOG_LEVEL_INFO = 0, - KV_LOG_LEVEL_DEBUG = 1, - KV_LOG_LEVEL_WARN = 2, - KV_LOG_LEVEL_ERROR = 3 + KV_LOG_LEVEL_INFO = 0, + KV_LOG_LEVEL_DEBUG = 1, + KV_LOG_LEVEL_WARN = 2, + KV_LOG_LEVEL_ERROR = 3 } KVLogLevel; -void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc, - int line, const char *fmt, ...) __attribute__ ((format (printf, 5, 6))); +void debug_log(KVLogLevel level, + bool enable_to_terminal, + const char *filefunc, + int line, + const char *fmt, + ...) __attribute__((format(printf, 5, 6))); /** - * @brief Unified logging macro to reduce duplication and improve maintainability. + * @brief Unified logging macro to reduce duplication and improve + * maintainability. * * @param level Log level (e.g., INFO, DEBUG, WARN, ERR). * @param to_terminal If true, the log will be printed to terminal. * @param ... Format string and arguments (like printf). */ #define KV_LOG(level, to_terminal, ...) \ - debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__) + debug_log(level, to_terminal, FILE_NAME(__FILE__), __LINE__, __VA_ARGS__) // Public logging macros with terminal output -#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__) -#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__) -#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__) -#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__) +#define WARN(...) KV_LOG(KV_LOG_LEVEL_WARN, true, __VA_ARGS__) +#define ERR(...) KV_LOG(KV_LOG_LEVEL_ERROR, true, __VA_ARGS__) +#define DEBUG(...) KV_LOG(KV_LOG_LEVEL_DEBUG, true, __VA_ARGS__) +#define INFO(...) KV_LOG(KV_LOG_LEVEL_INFO, true, __VA_ARGS__) #define gettid() ((pid_t)syscall(SYS_gettid)) -#define GET_CURRENT_TIME() do { \ - time_t timer = time(0); \ - struct tm* t = localtime(&timer); \ - char hostname[32]; \ - gethostname(hostname, 32); \ - sprintf(str, "%02d:%02d:%02d][%.32s][%d", \ - t->tm_hour, t->tm_min, t->tm_sec, hostname, gettid()); \ - } while (0) +#define GET_CURRENT_TIME() \ + do { \ + time_t timer = time(0); \ + struct tm *t = localtime(&timer); \ + char hostname[32]; \ + gethostname(hostname, 32); \ + sprintf(str, \ + "%02d:%02d:%02d][%.32s][%d", \ + t->tm_hour, \ + t->tm_min, \ + t->tm_sec, \ + hostname, \ + gettid()); \ + } while (0) -#define LOGE(fmt, arg...) do { \ - GET_CURRENT_TIME(); \ - fprintf(stderr, "[%s][ERR][KV_CACHE][%s:%d] " \ - fmt "\n",str, \ - FILE_NAME(__FILE__), __LINE__, ## arg); \ - } while (0) +#define LOGE(fmt, arg...) \ + do { \ + GET_CURRENT_TIME(); \ + fprintf(stderr, \ + "[%s][ERR][KV_CACHE][%s:%d] " fmt "\n", \ + str, \ + FILE_NAME(__FILE__), \ + __LINE__, \ + ##arg); \ + } while (0) -#define LOGW(fmt, arg...) do { \ - GET_CURRENT_TIME(); \ - fprintf(stderr, "[%s][WARN][KV_CACHE][%s:%d] " \ - fmt "\n",str, \ - FILE_NAME(__FILE__), __LINE__, ## arg); \ - } while (0) +#define LOGW(fmt, arg...) \ + do { \ + GET_CURRENT_TIME(); \ + fprintf(stderr, \ + "[%s][WARN][KV_CACHE][%s:%d] " fmt "\n", \ + str, \ + FILE_NAME(__FILE__), \ + __LINE__, \ + ##arg); \ + } while (0) -#define LOGI(fmt, arg...) do { \ - GET_CURRENT_TIME(); \ - fprintf(stdout, "[%s][INFO][KV_CACHE][%s:%d] " \ - fmt "\n",str, \ - FILE_NAME(__FILE__), __LINE__, ## arg); \ - } while (0) +#define LOGI(fmt, arg...) \ + do { \ + GET_CURRENT_TIME(); \ + fprintf(stdout, \ + "[%s][INFO][KV_CACHE][%s:%d] " fmt "\n", \ + str, \ + FILE_NAME(__FILE__), \ + __LINE__, \ + ##arg); \ + } while (0) -#define LOGD(fmt, arg...) do { \ - if (KV_IS_DEBUG_ENABLED) { \ - GET_CURRENT_TIME(); \ - fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \ - fmt "\n", str, \ - FILE_NAME(__FILE__), __LINE__, ## arg); \ - } \ - } while (0) +#define LOGD(fmt, arg...) \ + do { \ + if (KV_IS_DEBUG_ENABLED) { \ + GET_CURRENT_TIME(); \ + fprintf(stdout, \ + "[%s][DBG][KV_CACHE][%s:%d] " fmt "\n", \ + str, \ + FILE_NAME(__FILE__), \ + __LINE__, \ + ##arg); \ + } \ + } while (0) -#define LOGD_IF(cond, fmt, ...) do { \ - if ((cond)) \ - LOGD(fmt, __VA_ARGS__); \ - } while (0) +#define LOGD_IF(cond, fmt, ...) \ + do { \ + if ((cond)) LOGD(fmt, __VA_ARGS__); \ + } while (0) -#define LOGD_RAW(fmt, arg...) do { \ - if (ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED")) { \ - GET_CURRENT_TIME(); \ - fprintf(stdout, "[%s][DBG][KV_CACHE][%s:%d] " \ - fmt "\n", str, \ - FILE_NAME(__FILE__), __LINE__, ## arg); \ - } \ - } while (0) +#define LOGD_RAW(fmt, arg...) \ + do { \ + if (ENV_ENABLE_RAW("KV_IS_DEBUG_ENABLED")) { \ + GET_CURRENT_TIME(); \ + fprintf(stdout, \ + "[%s][DBG][KV_CACHE][%s:%d] " fmt "\n", \ + str, \ + FILE_NAME(__FILE__), \ + __LINE__, \ + ##arg); \ + } \ + } while (0) diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/util.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/util.h index c040b2a62..9dd8ebd99 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/util.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/util.h @@ -1,21 +1,21 @@ #ifndef KVCACHE_UTILS_H #define KVCACHE_UTILS_H -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include #include #include -#include -#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include #include "log.h" #define PATH_MAX 4096 /* # chars in a path name including nul */ @@ -28,22 +28,22 @@ /// @brief Connection status enumeration enum class ConnStatus { - kConnected, // Connection is active - kDisconnected, // Connection is not active - kError, // Connection error occurred - kTimeout, // Connection timed out - kInvalidParameters // Invalid connection parameters + kConnected, // Connection is active + kDisconnected, // Connection is not active + kError, // Connection error occurred + kTimeout, // Connection timed out + kInvalidParameters // Invalid connection parameters }; /// @brief Queue Pair (QP) setup result status enum class QpStatus { - kSuccess, // Successfully transitioned QP to RTS - kInvalidParameters, // ctx or dest is null - kDeviceQueryFailed, // ibv_query_device failed - kPortQueryFailed, // ibv_query_port failed - kMtuMismatch, // Requested MTU exceeds active MTU - kModifyToRTRFailed, // Failed to modify QP to RTR - kModifyToRTSFailed // Failed to modify QP to RTS + kSuccess, // Successfully transitioned QP to RTS + kInvalidParameters, // ctx or dest is null + kDeviceQueryFailed, // ibv_query_device failed + kPortQueryFailed, // ibv_query_port failed + kMtuMismatch, // Requested MTU exceeds active MTU + kModifyToRTRFailed, // Failed to modify QP to RTR + kModifyToRTSFailed // Failed to modify QP to RTS }; /** @@ -51,265 +51,281 @@ enum class QpStatus { * @param busId PCI bus ID string (e.g. "0000:3b:00.0") * @param[out] id Converted numeric ID */ -inline void busid_to_int64(const char *busId, int64_t *id) { - char hexStr[17] = {0}; - int hexOffset = 0; +inline void busid_to_int64(const char* busId, int64_t* id) { + char hexStr[17] = {0}; + int hexOffset = 0; - // Filter valid hex characters - for (int i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; i++) { - char c = busId[i]; - if (c == '.' || c == ':') continue; + // Filter valid hex characters + for (int i = 0; hexOffset < sizeof(hexStr) - 1 && busId[i] != '\0'; i++) { + char c = busId[i]; + if (c == '.' || c == ':') continue; - if ((c >= '0' && c <= '9') || - (c >= 'A' && c <= 'F') || - (c >= 'a' && c <= 'f')) { - hexStr[hexOffset++] = c; - } + if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F') || + (c >= 'a' && c <= 'f')) { + hexStr[hexOffset++] = c; } + } - *id = strtol(hexStr, NULL, 16); + *id = strtol(hexStr, NULL, 16); } class NetworkInterfaceManager { -public: - struct InterfaceInfo { - std::string name; - std::string ip; - bool is_up; - bool is_running; - bool is_loopback; + public: + struct InterfaceInfo { + std::string name; + std::string ip; + bool is_up; + bool is_running; + bool is_loopback; - bool isUsable() const { - return is_up && is_running && !is_loopback; - } - }; + bool isUsable() const { return is_up && is_running && !is_loopback; } + }; - static std::vector getAllInterfaces() { - std::vector interfaces; - struct ifaddrs *ifaddrs_ptr = nullptr; + static std::vector getAllInterfaces() { + std::vector interfaces; + struct ifaddrs* ifaddrs_ptr = nullptr; - if (getifaddrs(&ifaddrs_ptr) == -1) { - return interfaces; - } - - for (struct ifaddrs *ifa = ifaddrs_ptr; ifa != nullptr; ifa = ifa->ifa_next) { - if (ifa->ifa_addr == nullptr) continue; - if (ifa->ifa_addr->sa_family != AF_INET) continue; - - InterfaceInfo info; - info.name = ifa->ifa_name; - info.is_up = (ifa->ifa_flags & IFF_UP) != 0; - info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0; - info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0; - - struct sockaddr_in* sa = (struct sockaddr_in*)ifa->ifa_addr; - char ip_str[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, &sa->sin_addr, ip_str, INET_ADDRSTRLEN); - info.ip = ip_str; - - interfaces.push_back(info); - } - - freeifaddrs(ifaddrs_ptr); - return interfaces; + if (getifaddrs(&ifaddrs_ptr) == -1) { + return interfaces; } - static std::string getFirstUsableInterface() { - auto interfaces = getAllInterfaces(); + for (struct ifaddrs* ifa = ifaddrs_ptr; ifa != nullptr; + ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr) continue; + if (ifa->ifa_addr->sa_family != AF_INET) continue; - for (const auto& iface : interfaces) { - if (iface.isUsable()) { - return iface.name; - } - } - return ""; + InterfaceInfo info; + info.name = ifa->ifa_name; + info.is_up = (ifa->ifa_flags & IFF_UP) != 0; + info.is_running = (ifa->ifa_flags & IFF_RUNNING) != 0; + info.is_loopback = (ifa->ifa_flags & IFF_LOOPBACK) != 0; + + struct sockaddr_in* sa = (struct sockaddr_in*)ifa->ifa_addr; + char ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &sa->sin_addr, ip_str, INET_ADDRSTRLEN); + info.ip = ip_str; + + interfaces.push_back(info); } - static void displayAllInterfaces() { - auto interfaces = getAllInterfaces(); + freeifaddrs(ifaddrs_ptr); + return interfaces; + } - printf("Available network interfaces:\n"); - for (const auto& iface : interfaces) { - printf(" %s: %s [%s%s%s]\n", - iface.name.c_str(), - iface.ip.c_str(), - iface.is_up ? "UP" : "DOWN", - iface.is_running ? ",RUNNING" : "", - iface.is_loopback ? ",LOOPBACK" : ""); - } + static std::string getFirstUsableInterface() { + auto interfaces = getAllInterfaces(); + + for (const auto& iface : interfaces) { + if (iface.isUsable()) { + return iface.name; + } } + return ""; + } + + static void displayAllInterfaces() { + auto interfaces = getAllInterfaces(); + + printf("Available network interfaces:\n"); + for (const auto& iface : interfaces) { + printf(" %s: %s [%s%s%s]\n", + iface.name.c_str(), + iface.ip.c_str(), + iface.is_up ? "UP" : "DOWN", + iface.is_running ? ",RUNNING" : "", + iface.is_loopback ? ",LOOPBACK" : ""); + } + } }; class KVCacheConfig { -private: - // Configuration values - int rdma_gid_index_; - bool has_rdma_dest_port_override_; // 替代 std::optional - int rdma_dest_port_override_; - const char* socket_interface_; - char* socket_interface_buffer_; - bool gdrcopy_flush_enabled_; - bool verify_read_enabled_; - bool debug_mode_enabled_; - bool debug_output_enabled_; - const char* debug_file_path_; - const char* error_file_path_; - bool relax_ordering_enabled_; - int ib_timeout_; - const char* rdma_nics_; + private: + // Configuration values + int rdma_gid_index_; + bool has_rdma_dest_port_override_; // 替代 std::optional + int rdma_dest_port_override_; + const char* socket_interface_; + char* socket_interface_buffer_; + bool gdrcopy_flush_enabled_; + bool verify_read_enabled_; + bool debug_mode_enabled_; + bool debug_output_enabled_; + const char* debug_file_path_; + const char* error_file_path_; + bool relax_ordering_enabled_; + int ib_timeout_; + const char* rdma_nics_; - // Private constructor for singleton pattern - KVCacheConfig() { - // Initialize configuration from environment variables - rdma_gid_index_ = parse_int_value( - std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX"); + // Private constructor for singleton pattern + KVCacheConfig() { + // Initialize configuration from environment variables + rdma_gid_index_ = parse_int_value( + std::getenv("KVCACHE_RDMA_GID_INDEX"), 3, "KVCACHE_RDMA_GID_INDEX"); - // Parse optional RDMA port override - const char* port_value = std::getenv("SET_RDMA_DEST_PORT"); - has_rdma_dest_port_override_ = false; // 默认为false - if (port_value) { - try { - rdma_dest_port_override_ = std::stoi(std::string(port_value)); - has_rdma_dest_port_override_ = true; - } catch (const std::exception& e) { - fprintf(stderr, "Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n", port_value); - } - } - - const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME"); - - if (env_interface && env_interface[0] != '\0') { - socket_interface_ = env_interface; - printf("Using specified interface: %s\n", socket_interface_); - } else { - std::string iface = NetworkInterfaceManager::getFirstUsableInterface(); - if (!iface.empty()) { - socket_interface_buffer_ = new char[iface.size() + 1]; - std::strcpy(socket_interface_buffer_, iface.c_str()); - socket_interface_ = socket_interface_buffer_; - printf("Auto-detected interface: %s\n", socket_interface_); - } else { - fprintf(stderr, "Warning: No usable network interface found\n"); - socket_interface_ = ""; - } - NetworkInterfaceManager::displayAllInterfaces(); - } - - socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME"); - debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE"); - error_file_path_ = std::getenv("KVCACHE_ERROR_FILE"); - - gdrcopy_flush_enabled_ = parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE")); - verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ")); - debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) || - parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED")); - debug_output_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT")); - - relax_ordering_enabled_ = parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING")); - - ib_timeout_ = parse_int_value( - std::getenv("KVCACHE_IB_TIMEOUT"), - 18, - "KVCACHE_IB_TIMEOUT" - ); - - rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS"); + // Parse optional RDMA port override + const char* port_value = std::getenv("SET_RDMA_DEST_PORT"); + has_rdma_dest_port_override_ = false; // 默认为false + if (port_value) { + try { + rdma_dest_port_override_ = std::stoi(std::string(port_value)); + has_rdma_dest_port_override_ = true; + } catch (const std::exception& e) { + fprintf(stderr, + "Invalid SET_RDMA_DEST_PORT value: '%s', ignoring\n", + port_value); + } } - // Helper methods - bool parse_bool_value(const char* value) { - if (!value) return false; + const char* env_interface = std::getenv("KVCACHE_SOCKET_IFNAME"); - std::string str_value(value); - std::transform(str_value.begin(), str_value.end(), str_value.begin(), ::tolower); - - return (str_value == "1" || str_value == "true" || - str_value == "on" || str_value == "yes"); + if (env_interface && env_interface[0] != '\0') { + socket_interface_ = env_interface; + printf("Using specified interface: %s\n", socket_interface_); + } else { + std::string iface = NetworkInterfaceManager::getFirstUsableInterface(); + if (!iface.empty()) { + socket_interface_buffer_ = new char[iface.size() + 1]; + std::strcpy(socket_interface_buffer_, iface.c_str()); + socket_interface_ = socket_interface_buffer_; + printf("Auto-detected interface: %s\n", socket_interface_); + } else { + fprintf(stderr, "Warning: No usable network interface found\n"); + socket_interface_ = ""; + } + NetworkInterfaceManager::displayAllInterfaces(); } - int parse_int_value(const char* value, int default_value, const char* env_name) { - if (!value) return default_value; + socket_interface_ = std::getenv("KVCACHE_SOCKET_IFNAME"); + debug_file_path_ = std::getenv("KVCACHE_DEBUG_FILE"); + error_file_path_ = std::getenv("KVCACHE_ERROR_FILE"); - try { - return std::stoi(std::string(value)); - } catch (const std::invalid_argument& e) { - fprintf(stderr, "Invalid value for %s: '%s', using default: %d\n", - env_name, value, default_value); - return default_value; - } catch (const std::out_of_range& e) { - fprintf(stderr, "%s value out of range: '%s', using default: %d\n", - env_name, value, default_value); - return default_value; - } + gdrcopy_flush_enabled_ = + parse_bool_value(std::getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE")); + verify_read_enabled_ = parse_bool_value(std::getenv("KVCACHE_VERIFY_READ")); + debug_mode_enabled_ = parse_bool_value(std::getenv("KVCACHE_DEBUG")) || + parse_bool_value(std::getenv("KV_IS_DEBUG_ENABLED")); + debug_output_enabled_ = + parse_bool_value(std::getenv("KVCACHE_DEBUG_OUTPUT")); + + relax_ordering_enabled_ = + parse_bool_value(std::getenv("KVCACHE_RELAX_ORDERING")); + + ib_timeout_ = parse_int_value( + std::getenv("KVCACHE_IB_TIMEOUT"), 18, "KVCACHE_IB_TIMEOUT"); + + rdma_nics_ = std::getenv("KVCACHE_RDMA_NICS"); + } + + // Helper methods + bool parse_bool_value(const char* value) { + if (!value) return false; + + std::string str_value(value); + std::transform( + str_value.begin(), str_value.end(), str_value.begin(), ::tolower); + + return (str_value == "1" || str_value == "true" || str_value == "on" || + str_value == "yes"); + } + + int parse_int_value(const char* value, + int default_value, + const char* env_name) { + if (!value) return default_value; + + try { + return std::stoi(std::string(value)); + } catch (const std::invalid_argument& e) { + fprintf(stderr, + "Invalid value for %s: '%s', using default: %d\n", + env_name, + value, + default_value); + return default_value; + } catch (const std::out_of_range& e) { + fprintf(stderr, + "%s value out of range: '%s', using default: %d\n", + env_name, + value, + default_value); + return default_value; + } + } + + public: + // Prevent copying and assignment + KVCacheConfig(const KVCacheConfig&) = delete; + KVCacheConfig& operator=(const KVCacheConfig&) = delete; + + // Get singleton instance + static KVCacheConfig& getInstance() { + static KVCacheConfig instance; + return instance; + } + + int get_ib_timeout() const { return ib_timeout_; } + + // Configuration retrieval methods + int get_rdma_gid_index() const { return rdma_gid_index_; } + + int resolve_rdma_dest_port(int default_port) const { + return has_rdma_dest_port_override_ ? rdma_dest_port_override_ + : default_port; + } + + int resolve_rdma_dest_port(const std::string& default_port) const { + try { + return resolve_rdma_dest_port(std::stoi(default_port)); + } catch (const std::exception& e) { + fprintf( + stderr, "Invalid default port string: %s\n", default_port.c_str()); + return 0; + } + } + + const char* get_socket_interface() const { return socket_interface_; } + const char* get_debug_file_path() const { return debug_file_path_; } + const char* get_error_file_path() const { return error_file_path_; } + const char* get_rdma_nics() const { return rdma_nics_; } + + // Feature check methods + bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; } + bool is_verify_read_enabled() const { return verify_read_enabled_; } + bool is_debug_mode_enabled() const { return debug_mode_enabled_; } + bool is_debug_output_enabled() const { return debug_output_enabled_; } + bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; } + + // Display configuration + void displayConfiguration() const { + INFO("KVCache Configuration:\n"); + INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_); + + if (has_rdma_dest_port_override_) { + INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n", + rdma_dest_port_override_); } -public: - // Prevent copying and assignment - KVCacheConfig(const KVCacheConfig&) = delete; - KVCacheConfig& operator=(const KVCacheConfig&) = delete; - - // Get singleton instance - static KVCacheConfig& getInstance() { - static KVCacheConfig instance; - return instance; + if (socket_interface_) { + INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_); } - int get_ib_timeout() const { return ib_timeout_; } + INFO("Init KVCacheConfig GDRCopy Flush: %s\n", + gdrcopy_flush_enabled_ ? "enabled" : "disabled"); + INFO("Init KVCacheConfig Verify Read: %s\n", + verify_read_enabled_ ? "enabled" : "disabled"); + INFO("Init KVCacheConfig Debug Mode: %s\n", + debug_mode_enabled_ ? "enabled" : "disabled"); + INFO("Init KVCacheConfig Debug Output: %s\n", + debug_output_enabled_ ? "enabled" : "disabled"); - // Configuration retrieval methods - int get_rdma_gid_index() const { return rdma_gid_index_; } - - int resolve_rdma_dest_port(int default_port) const { - return has_rdma_dest_port_override_ ? rdma_dest_port_override_ : default_port; + if (debug_file_path_) { + INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_); } - int resolve_rdma_dest_port(const std::string& default_port) const { - try { - return resolve_rdma_dest_port(std::stoi(default_port)); - } catch (const std::exception& e) { - fprintf(stderr, "Invalid default port string: %s\n", default_port.c_str()); - return 0; - } - } - - const char* get_socket_interface() const { return socket_interface_; } - const char* get_debug_file_path() const { return debug_file_path_; } - const char* get_error_file_path() const { return error_file_path_; } - const char* get_rdma_nics() const { return rdma_nics_; } - - // Feature check methods - bool is_gdrcopy_flush_enabled() const { return gdrcopy_flush_enabled_; } - bool is_verify_read_enabled() const { return verify_read_enabled_; } - bool is_debug_mode_enabled() const { return debug_mode_enabled_; } - bool is_debug_output_enabled() const { return debug_output_enabled_; } - bool is_relax_ordering_enabled() const { return relax_ordering_enabled_; } - - // Display configuration - void displayConfiguration() const { - INFO("KVCache Configuration:\n"); - INFO("Init KVCacheConfig RDMA GID Index: %d\n", rdma_gid_index_); - - if (has_rdma_dest_port_override_) { - INFO("Init KVCacheConfig RDMA Destination Port Override: %d\n", rdma_dest_port_override_); - } - - if (socket_interface_) { - INFO("Init KVCacheConfig Socket Interface: %s\n", socket_interface_); - } - - INFO("Init KVCacheConfig GDRCopy Flush: %s\n", gdrcopy_flush_enabled_ ? "enabled" : "disabled"); - INFO("Init KVCacheConfig Verify Read: %s\n", verify_read_enabled_ ? "enabled" : "disabled"); - INFO("Init KVCacheConfig Debug Mode: %s\n", debug_mode_enabled_ ? "enabled" : "disabled"); - INFO("Init KVCacheConfig Debug Output: %s\n", debug_output_enabled_ ? "enabled" : "disabled"); - - if (debug_file_path_) { - INFO("Init KVCacheConfig Debug File: %s\n", debug_file_path_); - } - - if (error_file_path_) { - INFO("Init KVCacheConfig Error File: %s\n", error_file_path_); - } + if (error_file_path_) { + INFO("Init KVCacheConfig Error File: %s\n", error_file_path_); } + } }; #endif diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp index 6bb4e43a9..8e9ec468e 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_connection.cpp @@ -30,34 +30,34 @@ std::vector g_ib_all_devs; * @return PCI bus ID as int64_t, -1 on error */ static int64_t get_ib_busid(const char *dev_name) { - char dev_path[PATH_MAX]; - snprintf(dev_path, PATH_MAX, "/sys/class/infiniband/%s/device", dev_name); + char dev_path[PATH_MAX]; + snprintf(dev_path, PATH_MAX, "/sys/class/infiniband/%s/device", dev_name); - char *p = realpath(dev_path, NULL); - if (p == NULL) { - WARN("Failed to get realpath for device %s: %s", dev_name, strerror(errno)); - return -1; - } + char *p = realpath(dev_path, NULL); + if (p == NULL) { + WARN("Failed to get realpath for device %s: %s", dev_name, strerror(errno)); + return -1; + } - // Extract bus ID from path - int offset = strlen(p) - 1; - while (offset >= 0 && p[offset] != '/') { - offset--; - } + // Extract bus ID from path + int offset = strlen(p) - 1; + while (offset >= 0 && p[offset] != '/') { + offset--; + } - if (offset < 0) { - free(p); - return -1; - } - - char bus_str[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; - strncpy(bus_str, p + offset + 1, sizeof(bus_str) - 1); - bus_str[sizeof(bus_str) - 1] = '\0'; + if (offset < 0) { free(p); + return -1; + } - int64_t ret; - busid_to_int64(bus_str, &ret); - return ret; + char bus_str[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; + strncpy(bus_str, p + offset + 1, sizeof(bus_str) - 1); + bus_str[sizeof(bus_str) - 1] = '\0'; + free(p); + + int64_t ret; + busid_to_int64(bus_str, &ret); + return ret; } /** @@ -67,506 +67,538 @@ static int64_t get_ib_busid(const char *dev_name) { * @note This function is thread-safe and will only parse once */ int parse_port_ib_info() { - if (g_kvcache_ib_dev_nums != -1) return 0; + if (g_kvcache_ib_dev_nums != -1) return 0; - pthread_mutex_lock(&g_ib_lock); - if (g_kvcache_ib_dev_nums != -1) { - pthread_mutex_unlock(&g_ib_lock); - return 0; - } - - INFO("Initializing IB device information"); - g_kvcache_ib_dev_nums = 0; - - const char* env_nics = KVCacheConfig::getInstance().get_rdma_nics(); - if (!env_nics) { - ERR("Environment variable KVCACHE_RDMA_NICS not set"); - pthread_mutex_unlock(&g_ib_lock); - return -1; - } - - // Parse NIC list - char nic_names[MAXNAMESIZE][MAXNAMESIZE] = {0}; - int nic_count = 0; - char* env_copy = strdup(env_nics); - if (!env_copy) { - ERR("Failed to duplicate NIC list string"); - pthread_mutex_unlock(&g_ib_lock); - return -1; - } - - for (char* token = strtok(env_copy, ","); token && nic_count < MAXNAMESIZE; - token = strtok(NULL, ",")) { - strncpy(nic_names[nic_count++], token, MAXNAMESIZE - 1); - } - free(env_copy); - - // Get IB device list - int total_devs = 0; - ibv_device** dev_list = ibv_get_device_list(&total_devs); - if (!dev_list || total_devs <= 0) { - ERR("No IB devices found, ibv_get_device_list failed, total_devs = %d", total_devs); - pthread_mutex_unlock(&g_ib_lock); - return -1; - } - INFO("Found %d IB devices, filtering by NIC list", total_devs); - - for (int i = 0; i < total_devs && g_kvcache_ib_dev_nums < KVCACHE_RDMA_MAX_NICS; ++i) { - const char* dev_name = dev_list[i]->name; - - bool allowed = false; - for (int j = 0; j < nic_count; ++j) { - if (strcmp(dev_name, nic_names[j]) == 0) { - allowed = true; - break; - } - } - if (!allowed) { - WARN("Skipping device not in NIC list: %s", dev_name); - continue; - } - - ibv_context* ctx = ibv_open_device(dev_list[i]); - if (!ctx) { - ERR("Failed to open device %s: %s", dev_name, strerror(errno)); - continue; - } - - ibv_device_attr dev_attr = {}; - if (ibv_query_device(ctx, &dev_attr) != 0) { - ERR("Failed to query device %s: %s", dev_name, strerror(errno)); - ibv_close_device(ctx); - continue; - } - - int valid_ports = 0; - for (int port_num = 1; port_num <= dev_attr.phys_port_cnt; ++port_num) { - ibv_port_attr port_attr = {}; - if (ibv_query_port(ctx, port_num, &port_attr) != 0) { - WARN("Failed to query port %d on device %s: %s", port_num, dev_name, strerror(errno)); - continue; - } - - if (port_attr.state != IBV_PORT_ACTIVE) { - WARN("Port %d on device %s is not active (state: %d)", port_num, dev_name, port_attr.state); - continue; - } - - if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND && - port_attr.link_layer != IBV_LINK_LAYER_ETHERNET) { - WARN("Unsupported link layer %d on device %s port %d", port_attr.link_layer, dev_name, port_num); - continue; - } - - IbDeviceInfo dev_info = {}; - dev_info.device = i; - dev_info.guid = dev_attr.sys_image_guid; - dev_info.port = port_num; - dev_info.link = port_attr.link_layer; - dev_info.active_mtu = port_attr.active_mtu; - dev_info.context = ctx; - dev_info.busid = get_ib_busid(dev_name); - dev_info.maxQp = dev_attr.max_qp; - strncpy(dev_info.devName, dev_name, MAXNAMESIZE); - - INFO("Adding device %s port %d (%s)", dev_name, port_num, - port_attr.link_layer == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); - - g_ib_all_devs.push_back(dev_info); - ++g_kvcache_ib_dev_nums; - ++valid_ports; - } - - if (valid_ports == 0) { - ERR("No valid ports found for device %s", dev_name); - ibv_close_device(ctx); - } - } - - ibv_free_device_list(dev_list); - INFO("Initialized %d IB devices", g_kvcache_ib_dev_nums); + pthread_mutex_lock(&g_ib_lock); + if (g_kvcache_ib_dev_nums != -1) { pthread_mutex_unlock(&g_ib_lock); return 0; + } + + INFO("Initializing IB device information"); + g_kvcache_ib_dev_nums = 0; + + const char *env_nics = KVCacheConfig::getInstance().get_rdma_nics(); + if (!env_nics) { + ERR("Environment variable KVCACHE_RDMA_NICS not set"); + pthread_mutex_unlock(&g_ib_lock); + return -1; + } + + // Parse NIC list + char nic_names[MAXNAMESIZE][MAXNAMESIZE] = {0}; + int nic_count = 0; + char *env_copy = strdup(env_nics); + if (!env_copy) { + ERR("Failed to duplicate NIC list string"); + pthread_mutex_unlock(&g_ib_lock); + return -1; + } + + for (char *token = strtok(env_copy, ","); token && nic_count < MAXNAMESIZE; + token = strtok(NULL, ",")) { + strncpy(nic_names[nic_count++], token, MAXNAMESIZE - 1); + } + free(env_copy); + + // Get IB device list + int total_devs = 0; + ibv_device **dev_list = ibv_get_device_list(&total_devs); + if (!dev_list || total_devs <= 0) { + ERR("No IB devices found, ibv_get_device_list failed, total_devs = %d", + total_devs); + pthread_mutex_unlock(&g_ib_lock); + return -1; + } + INFO("Found %d IB devices, filtering by NIC list", total_devs); + + for (int i = 0; + i < total_devs && g_kvcache_ib_dev_nums < KVCACHE_RDMA_MAX_NICS; + ++i) { + const char *dev_name = dev_list[i]->name; + + bool allowed = false; + for (int j = 0; j < nic_count; ++j) { + if (strcmp(dev_name, nic_names[j]) == 0) { + allowed = true; + break; + } + } + if (!allowed) { + WARN("Skipping device not in NIC list: %s", dev_name); + continue; + } + + ibv_context *ctx = ibv_open_device(dev_list[i]); + if (!ctx) { + ERR("Failed to open device %s: %s", dev_name, strerror(errno)); + continue; + } + + ibv_device_attr dev_attr = {}; + if (ibv_query_device(ctx, &dev_attr) != 0) { + ERR("Failed to query device %s: %s", dev_name, strerror(errno)); + ibv_close_device(ctx); + continue; + } + + int valid_ports = 0; + for (int port_num = 1; port_num <= dev_attr.phys_port_cnt; ++port_num) { + ibv_port_attr port_attr = {}; + if (ibv_query_port(ctx, port_num, &port_attr) != 0) { + WARN("Failed to query port %d on device %s: %s", + port_num, + dev_name, + strerror(errno)); + continue; + } + + if (port_attr.state != IBV_PORT_ACTIVE) { + WARN("Port %d on device %s is not active (state: %d)", + port_num, + dev_name, + port_attr.state); + continue; + } + + if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND && + port_attr.link_layer != IBV_LINK_LAYER_ETHERNET) { + WARN("Unsupported link layer %d on device %s port %d", + port_attr.link_layer, + dev_name, + port_num); + continue; + } + + IbDeviceInfo dev_info = {}; + dev_info.device = i; + dev_info.guid = dev_attr.sys_image_guid; + dev_info.port = port_num; + dev_info.link = port_attr.link_layer; + dev_info.active_mtu = port_attr.active_mtu; + dev_info.context = ctx; + dev_info.busid = get_ib_busid(dev_name); + dev_info.maxQp = dev_attr.max_qp; + strncpy(dev_info.devName, dev_name, MAXNAMESIZE); + + INFO("Adding device %s port %d (%s)", + dev_name, + port_num, + port_attr.link_layer == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); + + g_ib_all_devs.push_back(dev_info); + ++g_kvcache_ib_dev_nums; + ++valid_ports; + } + + if (valid_ports == 0) { + ERR("No valid ports found for device %s", dev_name); + ibv_close_device(ctx); + } + } + + ibv_free_device_list(dev_list); + INFO("Initialized %d IB devices", g_kvcache_ib_dev_nums); + pthread_mutex_unlock(&g_ib_lock); + return 0; } static int modify_qp_to_init(struct ibv_qp *qp, struct ibv_qp_attr *attr) { - int ret = ibv_modify_qp( - qp, - attr, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS); - if (ret != 0) { - ERR("Failed to modify QP to INIT: %s (errno=%d)", strerror(errno), errno); - } + int ret = ibv_modify_qp( + qp, + attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS); + if (ret != 0) { + ERR("Failed to modify QP to INIT: %s (errno=%d)", strerror(errno), errno); + } - return ret; + return ret; } static int modify_qp_to_rtr(struct ibv_qp *qp, struct ibv_qp_attr *attr) { - int ret = ibv_modify_qp( - qp, - attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); - if (ret != 0) { - ERR("Failed to modify QP to RTR: %s (errno=%d)", strerror(errno), errno); - } + int ret = ibv_modify_qp(qp, + attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + if (ret != 0) { + ERR("Failed to modify QP to RTR: %s (errno=%d)", strerror(errno), errno); + } - return ret; + return ret; } static int modify_rtr_to_rts(struct ibv_qp *qp, struct ibv_qp_attr *attr) { - int ret = ibv_modify_qp( - qp, - attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC); - if (ret != 0) { - ERR("Failed to modify QP to RTS: %s (errno=%d)", strerror(errno), errno); - } - return ret; + int ret = ibv_modify_qp(qp, + attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC); + if (ret != 0) { + ERR("Failed to modify QP to RTS: %s (errno=%d)", strerror(errno), errno); + } + return ret; } -int server_exchange_qp_info(int connfd, QpInfo* local_dest, QpInfo* rem_dest) { - if (!local_dest || !rem_dest) { - ERR("Null pointer passed to server_exchange_qp_info"); - return -1; - } +int server_exchange_qp_info(int connfd, QpInfo *local_dest, QpInfo *rem_dest) { + if (!local_dest || !rem_dest) { + ERR("Null pointer passed to server_exchange_qp_info"); + return -1; + } - char buffer[QpInfo::size]; - memset(buffer, 0, sizeof(buffer)); + char buffer[QpInfo::size]; + memset(buffer, 0, sizeof(buffer)); - // Read remote QP info from the connection - int n = read(connfd, buffer, QpInfo::size); - if (n != static_cast(QpInfo::size)) { - ERR("Failed to read remote QP info: read %d bytes, expected %zu", n, QpInfo::size); - return -1; - } + // Read remote QP info from the connection + int n = read(connfd, buffer, QpInfo::size); + if (n != static_cast(QpInfo::size)) { + ERR("Failed to read remote QP info: read %d bytes, expected %zu", + n, + QpInfo::size); + return -1; + } - QpInfo remote_msg; - remote_msg.deserialize(buffer); - *rem_dest = remote_msg; - rem_dest->psn = 0; + QpInfo remote_msg; + remote_msg.deserialize(buffer); + *rem_dest = remote_msg; + rem_dest->psn = 0; - // Prepare local QP info to send - QpInfo local_msg = *local_dest; - local_msg.psn = 0; - local_msg.serialize(buffer); + // Prepare local QP info to send + QpInfo local_msg = *local_dest; + local_msg.psn = 0; + local_msg.serialize(buffer); - // Send local QP info to the remote side - n = write(connfd, buffer, QpInfo::size); - if (n != static_cast(QpInfo::size)) { - ERR("Failed to send local QP info: wrote %d bytes", n); - return -1; - } + // Send local QP info to the remote side + n = write(connfd, buffer, QpInfo::size); + if (n != static_cast(QpInfo::size)) { + ERR("Failed to send local QP info: wrote %d bytes", n); + return -1; + } - return 0; + return 0; } -int get_port_info(struct ibv_context *Context, int port, struct ibv_port_attr *attr) { - return ibv_query_port(Context, port, attr); +int get_port_info(struct ibv_context *Context, + int port, + struct ibv_port_attr *attr) { + return ibv_query_port(Context, port, attr); } -QpStatus modify_qp_to_rts( - struct RdmaContext *ctx, - int port, - int my_psn, - struct QpInfo *dest, - int sgid_id) { - if (!ctx || !dest) { - ERR("Invalid input parameters: ctx or dest is NULL"); - return QpStatus::kInvalidParameters; - } +QpStatus modify_qp_to_rts(struct RdmaContext *ctx, + int port, + int my_psn, + struct QpInfo *dest, + int sgid_id) { + if (!ctx || !dest) { + ERR("Invalid input parameters: ctx or dest is NULL"); + return QpStatus::kInvalidParameters; + } - struct ibv_device_attr dev_attr; - if (ibv_query_device(ctx->context, &dev_attr)) { - ERR("Failed to query device attributes: %s (errno=%d)", strerror(errno), errno); - return QpStatus::kDeviceQueryFailed; - } + struct ibv_device_attr dev_attr; + if (ibv_query_device(ctx->context, &dev_attr)) { + ERR("Failed to query device attributes: %s (errno=%d)", + strerror(errno), + errno); + return QpStatus::kDeviceQueryFailed; + } - struct ibv_port_attr port_attr; - if (ibv_query_port(ctx->context, port, &port_attr)) { - ERR("Failed to query port attributes: %s (errno=%d)", strerror(errno), errno); - return QpStatus::kPortQueryFailed; - } + struct ibv_port_attr port_attr; + if (ibv_query_port(ctx->context, port, &port_attr)) { + ERR("Failed to query port attributes: %s (errno=%d)", + strerror(errno), + errno); + return QpStatus::kPortQueryFailed; + } - if (dest->mtu > port_attr.active_mtu) { - ERR("Specified MTU (%d) is greater than active port MTU (%d)", dest->mtu, port_attr.active_mtu); - return QpStatus::kMtuMismatch; - } + if (dest->mtu > port_attr.active_mtu) { + ERR("Specified MTU (%d) is greater than active port MTU (%d)", + dest->mtu, + port_attr.active_mtu); + return QpStatus::kMtuMismatch; + } - struct ibv_qp_attr attr; - memset(&attr, 0, sizeof(struct ibv_qp_attr)); + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(struct ibv_qp_attr)); - attr.qp_state = IBV_QPS_RTR; - attr.path_mtu = dest->mtu; - attr.dest_qp_num = dest->qpn; - attr.rq_psn = 0; - attr.max_dest_rd_atomic = 1; - attr.min_rnr_timer = 12; + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = dest->mtu; + attr.dest_qp_num = dest->qpn; + attr.rq_psn = 0; + attr.max_dest_rd_atomic = 1; + attr.min_rnr_timer = 12; - attr.ah_attr.is_global = 1; - attr.ah_attr.grh.hop_limit = 255; - attr.ah_attr.grh.flow_label = 0; - attr.ah_attr.grh.traffic_class = 0; - attr.ah_attr.grh.dgid.global.subnet_prefix = (dest->gid.global.subnet_prefix); - attr.ah_attr.grh.dgid.global.interface_id = (dest->gid.global.interface_id); - attr.ah_attr.grh.sgid_index = sgid_id; + attr.ah_attr.is_global = 1; + attr.ah_attr.grh.hop_limit = 255; + attr.ah_attr.grh.flow_label = 0; + attr.ah_attr.grh.traffic_class = 0; + attr.ah_attr.grh.dgid.global.subnet_prefix = (dest->gid.global.subnet_prefix); + attr.ah_attr.grh.dgid.global.interface_id = (dest->gid.global.interface_id); + attr.ah_attr.grh.sgid_index = sgid_id; + attr.ah_attr.src_path_bits = 0; + attr.ah_attr.port_num = port; - attr.ah_attr.src_path_bits = 0; - attr.ah_attr.port_num = port; + if (modify_qp_to_rtr(ctx->qp, &attr) != 0) { + return QpStatus::kModifyToRTRFailed; + } - if (modify_qp_to_rtr(ctx->qp, &attr) != 0) { - return QpStatus::kModifyToRTRFailed; - } + int qp_timeout = KVCacheConfig::getInstance().get_ib_timeout(); + attr.qp_state = IBV_QPS_RTS; + attr.timeout = qp_timeout; + attr.retry_cnt = 7; + attr.rnr_retry = 7; + attr.sq_psn = 0; + attr.max_rd_atomic = 1; - int qp_timeout = KVCacheConfig::getInstance().get_ib_timeout(); - attr.qp_state = IBV_QPS_RTS; - attr.timeout = qp_timeout; - attr.retry_cnt = 7; - attr.rnr_retry = 7; - attr.sq_psn = 0; - attr.max_rd_atomic = 1; + if (modify_rtr_to_rts(ctx->qp, &attr) != 0) { + return QpStatus::kModifyToRTSFailed; + } - if (modify_rtr_to_rts(ctx->qp, &attr) != 0) { - return QpStatus::kModifyToRTSFailed; - } - - LOGD("QP successfully transitioned to RTS state"); - return QpStatus::kSuccess; + LOGD("QP successfully transitioned to RTS state"); + return QpStatus::kSuccess; } -static std::shared_ptr client_exch_dest( - struct RdmaContext *ctx, - const std::string &dst_ip, - int port, - const QpInfo *my_dest) { - struct addrinfo hints = {}; - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; +static std::shared_ptr client_exch_dest(struct RdmaContext *ctx, + const std::string &dst_ip, + int port, + const QpInfo *my_dest) { + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; - struct addrinfo *res = nullptr; - std::ostringstream service; - service << port; + struct addrinfo *res = nullptr; + std::ostringstream service; + service << port; - int ret = getaddrinfo(dst_ip.c_str(), service.str().c_str(), &hints, &res); - if (ret != 0) { - ERR("getaddrinfo failed for %s:%d - %s", dst_ip.c_str(), port, gai_strerror(ret)); - return nullptr; - } - - int sockfd = -1; - for (struct addrinfo *ai = res; ai; ai = ai->ai_next) { - sockfd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); - if (sockfd < 0) { - WARN("Socket creation failed: %s", strerror(errno)); - continue; - } - - int enable = 1; - setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)); - int keep_idle = 10, keep_intvl = 5, keep_cnt = 3; - setsockopt(sockfd, SOL_TCP, TCP_KEEPIDLE, &keep_idle, sizeof(keep_idle)); - setsockopt(sockfd, SOL_TCP, TCP_KEEPINTVL, &keep_intvl, sizeof(keep_intvl)); - setsockopt(sockfd, SOL_TCP, TCP_KEEPCNT, &keep_cnt, sizeof(keep_cnt)); - - if (connect(sockfd, ai->ai_addr, ai->ai_addrlen) == 0) { - break; // Connected - } - - WARN("Connect failed: %s", strerror(errno)); - close(sockfd); - sockfd = -1; - } - - freeaddrinfo(res); + int ret = getaddrinfo(dst_ip.c_str(), service.str().c_str(), &hints, &res); + if (ret != 0) { + ERR("getaddrinfo failed for %s:%d - %s", + dst_ip.c_str(), + port, + gai_strerror(ret)); + return nullptr; + } + int sockfd = -1; + for (struct addrinfo *ai = res; ai; ai = ai->ai_next) { + sockfd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); if (sockfd < 0) { - ERR("Unable to connect to %s:%d", dst_ip.c_str(), port); - return nullptr; + WARN("Socket creation failed: %s", strerror(errno)); + continue; } - ctx->sock_fd = sockfd; + int enable = 1; + setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)); + int keep_idle = 10, keep_intvl = 5, keep_cnt = 3; + setsockopt(sockfd, SOL_TCP, TCP_KEEPIDLE, &keep_idle, sizeof(keep_idle)); + setsockopt(sockfd, SOL_TCP, TCP_KEEPINTVL, &keep_intvl, sizeof(keep_intvl)); + setsockopt(sockfd, SOL_TCP, TCP_KEEPCNT, &keep_cnt, sizeof(keep_cnt)); - char buffer[QpInfo::size] = {}; - QpInfo(*my_dest).serialize(buffer); - - if (write(sockfd, buffer, QpInfo::size) != QpInfo::size) { - WARN("Failed to send local QP info to %s", dst_ip.c_str()); - close(sockfd); - return nullptr; + if (connect(sockfd, ai->ai_addr, ai->ai_addrlen) == 0) { + break; // Connected } - if (read(sockfd, buffer, QpInfo::size) != QpInfo::size) { - WARN("Failed to receive remote QP info from %s", dst_ip.c_str()); - close(sockfd); - return nullptr; - } + WARN("Connect failed: %s", strerror(errno)); + close(sockfd); + sockfd = -1; + } - // I think no need to check memory allocate, because once allocate failed, - // that's mean the process encountering OOM, let it crash then check whether - // the code logic has memory leak or not. - auto rem_dest = std::make_shared(); - rem_dest->deserialize(buffer); - return rem_dest; + freeaddrinfo(res); + + if (sockfd < 0) { + ERR("Unable to connect to %s:%d", dst_ip.c_str(), port); + return nullptr; + } + + ctx->sock_fd = sockfd; + + char buffer[QpInfo::size] = {}; + QpInfo(*my_dest).serialize(buffer); + + if (write(sockfd, buffer, QpInfo::size) != QpInfo::size) { + WARN("Failed to send local QP info to %s", dst_ip.c_str()); + close(sockfd); + return nullptr; + } + + if (read(sockfd, buffer, QpInfo::size) != QpInfo::size) { + WARN("Failed to receive remote QP info from %s", dst_ip.c_str()); + close(sockfd); + return nullptr; + } + + // I think no need to check memory allocate, because once allocate failed, + // that's mean the process encountering OOM, let it crash then check whether + // the code logic has memory leak or not. + auto rem_dest = std::make_shared(); + rem_dest->deserialize(buffer); + return rem_dest; } -bool poll_cq_with_timeout(struct RdmaContext *ctx, int timeout_seconds, int cqe_count) { - struct timespec start_time, current_time; - struct ibv_wc *wc_array = (struct ibv_wc *)malloc(cqe_count * sizeof(struct ibv_wc)); +bool poll_cq_with_timeout(struct RdmaContext *ctx, + int timeout_seconds, + int cqe_count) { + struct timespec start_time, current_time; + struct ibv_wc *wc_array = + (struct ibv_wc *)malloc(cqe_count * sizeof(struct ibv_wc)); - if (!wc_array) { - ERR("Failed to allocate memory for WC array"); - return false; + if (!wc_array) { + ERR("Failed to allocate memory for WC array"); + return false; + } + + clock_gettime(CLOCK_MONOTONIC, &start_time); + + while (1) { + int poll_result = ibv_poll_cq(ctx->cq, cqe_count, wc_array); + + if (poll_result < 0) { + ERR("ibv_poll_cq failed with return value %d", poll_result); + free(wc_array); + return false; + } else if (poll_result > 0) { + for (int i = 0; i < poll_result; ++i) { + if (wc_array[i].status == IBV_WC_SUCCESS) { + LOGD("Work completion %d successful", poll_result); + } else { + LOGD("Work completion %d status is %d (%s)", + poll_result, + wc_array[i].status, + ibv_wc_status_str(wc_array[i].status)); + } + } + free(wc_array); + return true; } - clock_gettime(CLOCK_MONOTONIC, &start_time); - - while (1) { - int poll_result = ibv_poll_cq(ctx->cq, cqe_count, wc_array); - - if (poll_result < 0) { - ERR("ibv_poll_cq failed with return value %d", poll_result); - free(wc_array); - return false; - } else if (poll_result > 0) { - for (int i = 0; i < poll_result; ++i) { - if (wc_array[i].status == IBV_WC_SUCCESS) { - LOGD("Work completion %d successful", poll_result); - } else { - LOGD("Work completion %d status is %d (%s)", - poll_result, wc_array[i].status, ibv_wc_status_str(wc_array[i].status)); - } - } - free(wc_array); - return true; - } - - clock_gettime(CLOCK_MONOTONIC, ¤t_time); - if ((current_time.tv_sec - start_time.tv_sec) >= timeout_seconds) { - ERR("Timeout occurred after %d seconds", timeout_seconds); - free(wc_array); - return false; - } + clock_gettime(CLOCK_MONOTONIC, ¤t_time); + if ((current_time.tv_sec - start_time.tv_sec) >= timeout_seconds) { + ERR("Timeout occurred after %d seconds", timeout_seconds); + free(wc_array); + return false; } - return true; + } + return true; } -bool clear_qp_info(struct RdmaContext* ctx) { - if (!ctx) { - ERR("RdmaContext pointer is null."); - return false; +bool clear_qp_info(struct RdmaContext *ctx) { + if (!ctx) { + ERR("RdmaContext pointer is null."); + return false; + } + + bool success = true; + + if (ctx->qp) { + if (ibv_destroy_qp(ctx->qp)) { + ERR("Failed to destroy QP."); + success = false; } + } - bool success = true; - - if (ctx->qp) { - if (ibv_destroy_qp(ctx->qp)) { - ERR("Failed to destroy QP."); - success = false; - } + if (ctx->cq) { + if (ibv_destroy_cq(ctx->cq)) { + ERR("Failed to deallocate cq Domain."); + success = false; } + } - if (ctx->cq) { - if (ibv_destroy_cq(ctx->cq)) { - ERR("Failed to deallocate cq Domain."); - success = false; - } + if (ctx->channel) { + if (ibv_destroy_comp_channel(ctx->channel)) { + ERR("Failed to destroy Completion Channel."); + success = false; } + } - if (ctx->channel) { - if (ibv_destroy_comp_channel(ctx->channel)) { - ERR("Failed to destroy Completion Channel."); - success = false; - } - } - - return success; + return success; } -struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd) { - struct RdmaContext* ctx = new RdmaContext(); - memset(ctx, 0, sizeof(struct RdmaContext)); - struct ibv_qp_init_attr qpInitAttr = {}; - ctx->context = ib_dev->context; +struct RdmaContext *create_qp(struct IbDeviceInfo *ib_dev, + struct ibv_pd **g_pd) { + struct RdmaContext *ctx = new RdmaContext(); + memset(ctx, 0, sizeof(struct RdmaContext)); + struct ibv_qp_init_attr qpInitAttr = {}; + ctx->context = ib_dev->context; + if (*g_pd == NULL) { + *g_pd = ibv_alloc_pd(ctx->context); if (*g_pd == NULL) { - *g_pd = ibv_alloc_pd(ctx->context); - if (*g_pd == NULL) { - ERR("failed to allocate protection domain"); - free(ctx->context); - return NULL; - } + ERR("failed to allocate protection domain"); + free(ctx->context); + return NULL; } - ctx->pd = *g_pd; + } + ctx->pd = *g_pd; - // Create completion channel - ctx->channel = ibv_create_comp_channel(ctx->context); - if (!ctx->channel) { - ERR("Failed to create completion channel: %s", strerror(errno)); - delete ctx; - return NULL; - } + // Create completion channel + ctx->channel = ibv_create_comp_channel(ctx->context); + if (!ctx->channel) { + ERR("Failed to create completion channel: %s", strerror(errno)); + delete ctx; + return NULL; + } - // Create completion queue - ctx->cq = ibv_create_cq(ctx->context, 4096, ctx, ctx->channel, 0); - if (!ctx->cq) { - ERR("Failed to create completion queue: %s", strerror(errno)); - ibv_destroy_comp_channel(ctx->channel); - delete ctx; - return NULL; - } + // Create completion queue + ctx->cq = ibv_create_cq(ctx->context, 4096, ctx, ctx->channel, 0); + if (!ctx->cq) { + ERR("Failed to create completion queue: %s", strerror(errno)); + ibv_destroy_comp_channel(ctx->channel); + delete ctx; + return NULL; + } - // Request completion notifications - if (ibv_req_notify_cq(ctx->cq, 0)) { - ERR("Failed to request CQ notifications: %s", strerror(errno)); - ibv_destroy_cq(ctx->cq); - ibv_destroy_comp_channel(ctx->channel); - delete ctx; - return NULL; - } + // Request completion notifications + if (ibv_req_notify_cq(ctx->cq, 0)) { + ERR("Failed to request CQ notifications: %s", strerror(errno)); + ibv_destroy_cq(ctx->cq); + ibv_destroy_comp_channel(ctx->channel); + delete ctx; + return NULL; + } - // Initialize QP attributes - qpInitAttr.send_cq = ctx->cq; - qpInitAttr.recv_cq = ctx->cq; - qpInitAttr.qp_type = IBV_QPT_RC; - qpInitAttr.cap.max_send_wr = 4096; - qpInitAttr.cap.max_recv_wr = 4096; - qpInitAttr.cap.max_send_sge = 1; - qpInitAttr.cap.max_recv_sge = 1; - qpInitAttr.cap.max_inline_data = 0; + // Initialize QP attributes + qpInitAttr.send_cq = ctx->cq; + qpInitAttr.recv_cq = ctx->cq; + qpInitAttr.qp_type = IBV_QPT_RC; + qpInitAttr.cap.max_send_wr = 4096; + qpInitAttr.cap.max_recv_wr = 4096; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; + qpInitAttr.cap.max_inline_data = 0; - // Create queue pair - ctx->qp = ibv_create_qp(ctx->pd, &qpInitAttr); - if (!ctx->qp) { - ERR("Failed to create queue pair: %s", strerror(errno)); - ibv_destroy_cq(ctx->cq); - ibv_destroy_comp_channel(ctx->channel); - delete ctx; - return NULL; - } + // Create queue pair + ctx->qp = ibv_create_qp(ctx->pd, &qpInitAttr); + if (!ctx->qp) { + ERR("Failed to create queue pair: %s", strerror(errno)); + ibv_destroy_cq(ctx->cq); + ibv_destroy_comp_channel(ctx->channel); + delete ctx; + return NULL; + } - // Modify QP to INIT state - struct ibv_qp_attr qpAttr = {}; - qpAttr.qp_state = IBV_QPS_INIT; - qpAttr.pkey_index = 0; - qpAttr.port_num = 1; - qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + // Modify QP to INIT state + struct ibv_qp_attr qpAttr = {}; + qpAttr.qp_state = IBV_QPS_INIT; + qpAttr.pkey_index = 0; + qpAttr.port_num = 1; + qpAttr.qp_access_flags = + IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; - int ret = modify_qp_to_init(ctx->qp, &qpAttr); - if (ret != 0) { - ERR("Failed to modify QP to INIT state: %s (ret=%d)", strerror(errno), ret); - ibv_destroy_qp(ctx->qp); - ibv_destroy_cq(ctx->cq); - ibv_destroy_comp_channel(ctx->channel); - delete ctx; - return NULL; - } + int ret = modify_qp_to_init(ctx->qp, &qpAttr); + if (ret != 0) { + ERR("Failed to modify QP to INIT state: %s (ret=%d)", strerror(errno), ret); + ibv_destroy_qp(ctx->qp); + ibv_destroy_cq(ctx->cq); + ibv_destroy_comp_channel(ctx->channel); + delete ctx; + return NULL; + } - INFO("Successfully created QP 0x%x on device %s", - ctx->qp->qp_num, ib_dev->devName); + INFO("Successfully created QP 0x%x on device %s", + ctx->qp->qp_num, + ib_dev->devName); - return ctx; + return ctx; } /** @@ -579,77 +611,79 @@ struct RdmaContext* create_qp(struct IbDeviceInfo* ib_dev, struct ibv_pd** g_pd) * @param dst_ip Destination IP address * @return true on success, false on failure */ -bool client_exchange_destinations( - struct RdmaContext* ctx, - int ib_port, - unsigned int port, - int gidx, - const std::string& dst_ip) { +bool client_exchange_destinations(struct RdmaContext *ctx, + int ib_port, + unsigned int port, + int gidx, + const std::string &dst_ip) { + if (!ctx || !ctx->context || !ctx->qp) { + ERR("Invalid RDMA context or QP not initialized"); + return false; + } - if (!ctx || !ctx->context || !ctx->qp) { - ERR("Invalid RDMA context or QP not initialized"); - return false; + LOGD("Exchanging destination info with %s:%u", dst_ip.c_str(), port); + + // Get local QP information + struct QpInfo my_dest = {}; + if (get_port_info(ctx->context, ib_port, &ctx->portinfo)) { + ERR("Failed to get port info for port %d", ib_port); + return false; + } + + my_dest.lid = ctx->portinfo.lid; + my_dest.mtu = ctx->portinfo.active_mtu; + + // Validate LID for InfiniBand + if (ctx->portinfo.link_layer != IBV_LINK_LAYER_ETHERNET && !my_dest.lid) { + ERR("Invalid LID 0x%04x for non-Ethernet link layer", my_dest.lid); + return false; + } + + // Get GID if specified + if (gidx >= 0) { + if (ibv_query_gid(ctx->context, ib_port, gidx, &my_dest.gid)) { + ERR("Failed to query GID for index %d on port %d", gidx, ib_port); + return false; } + } else { + memset(&my_dest.gid, 0, sizeof(my_dest.gid)); + } - LOGD("Exchanging destination info with %s:%u", dst_ip.c_str(), port); + my_dest.qpn = ctx->qp->qp_num; + my_dest.psn = lrand48() & 0xffffff; - // Get local QP information - struct QpInfo my_dest = {}; - if (get_port_info(ctx->context, ib_port, &ctx->portinfo)) { - ERR("Failed to get port info for port %d", ib_port); - return false; - } + // Log local address info + char gid_str[33] = {0}; + inet_ntop(AF_INET6, &my_dest.gid, gid_str, sizeof(gid_str)); - my_dest.lid = ctx->portinfo.lid; - my_dest.mtu = ctx->portinfo.active_mtu; + if (dst_ip.empty()) { + ERR("Empty destination IP address"); + return false; + } - // Validate LID for InfiniBand - if (ctx->portinfo.link_layer != IBV_LINK_LAYER_ETHERNET && !my_dest.lid) { - ERR("Invalid LID 0x%04x for non-Ethernet link layer", my_dest.lid); - return false; - } + // Exchange destination info with remote + auto rem_dest = client_exch_dest(ctx, dst_ip, port, &my_dest); + if (!rem_dest) { + ERR("Failed to exchange destination info with %s:%u", dst_ip.c_str(), port); + return false; + } - // Get GID if specified - if (gidx >= 0) { - if (ibv_query_gid(ctx->context, ib_port, gidx, &my_dest.gid)) { - ERR("Failed to query GID for index %d on port %d", gidx, ib_port); - return false; - } - } else { - memset(&my_dest.gid, 0, sizeof(my_dest.gid)); - } + LOGD("Remote address - LID: 0x%04x, QPN: 0x%06x, PSN: 0x%06x, Mtu: %u", + rem_dest->lid, + rem_dest->qpn, + rem_dest->psn, + rem_dest->mtu); - my_dest.qpn = ctx->qp->qp_num; - my_dest.psn = lrand48() & 0xffffff; + // Modify QP to RTS state + if (modify_qp_to_rts(ctx, ib_port, my_dest.psn, rem_dest.get(), gidx) != + QpStatus::kSuccess) { + ERR("Failed to modify QP 0x%x to RTS state", ctx->qp->qp_num); + return false; + } - // Log local address info - char gid_str[33] = {0}; - inet_ntop(AF_INET6, &my_dest.gid, gid_str, sizeof(gid_str)); + LOGD("Successfully established connection to %s:%u", dst_ip.c_str(), port); - if (dst_ip.empty()) { - ERR("Empty destination IP address"); - return false; - } - - // Exchange destination info with remote - auto rem_dest = client_exch_dest(ctx, dst_ip, port, &my_dest); - if (!rem_dest) { - ERR("Failed to exchange destination info with %s:%u", dst_ip.c_str(), port); - return false; - } - - LOGD("Remote address - LID: 0x%04x, QPN: 0x%06x, PSN: 0x%06x, Mtu: %u", - rem_dest->lid, rem_dest->qpn, rem_dest->psn, rem_dest->mtu); - - // Modify QP to RTS state - if (modify_qp_to_rts(ctx, ib_port, my_dest.psn, rem_dest.get(), gidx) != QpStatus::kSuccess) { - ERR("Failed to modify QP 0x%x to RTS state", ctx->qp->qp_num); - return false; - } - - LOGD("Successfully established connection to %s:%u", dst_ip.c_str(), port); - - return true; + return true; } /** @@ -659,13 +693,17 @@ bool client_exchange_destinations( * @param is_client True if this is the client side operation, false for server * @return true on success, false on failure */ -template -bool exchange_mr_vector(struct RdmaContext *ctx, std::vector& data_list, bool is_client) { - if (is_client) { - return client_receive_memory_region(ctx, data_list.data(), data_list.size() * sizeof(T)); - } else { - return server_send_memory_region(ctx, data_list.data(), data_list.size() * sizeof(T)); - } +template +bool exchange_mr_vector(struct RdmaContext *ctx, + std::vector &data_list, + bool is_client) { + if (is_client) { + return client_receive_memory_region( + ctx, data_list.data(), data_list.size() * sizeof(T)); + } else { + return server_send_memory_region( + ctx, data_list.data(), data_list.size() * sizeof(T)); + } } /** @@ -675,31 +713,31 @@ bool exchange_mr_vector(struct RdmaContext *ctx, std::vector& data_list, bool * @return true on success, false on failure */ bool client_exchange_mr(struct RdmaContext *ctx) { - LOGD("verb client exchange mr: start"); + LOGD("verb client exchange mr: start"); - if (ctx->conn.layer_number <= 0) { - ERR("Invalid layer number: %d", ctx->conn.layer_number); - return false; - } + if (ctx->conn.layer_number <= 0) { + ERR("Invalid layer number: %d", ctx->conn.layer_number); + return false; + } - auto layer_num = ctx->conn.layer_number; - std::vector key_ptrs(layer_num); - std::vector key_rkeys(layer_num); - std::vector val_ptrs(layer_num); - std::vector val_rkeys(layer_num); + auto layer_num = ctx->conn.layer_number; + std::vector key_ptrs(layer_num); + std::vector key_rkeys(layer_num); + std::vector val_ptrs(layer_num); + std::vector val_rkeys(layer_num); - if (!exchange_mr_vector(ctx, key_ptrs, true)) return false; - if (!exchange_mr_vector(ctx, key_rkeys, true)) return false; - if (!exchange_mr_vector(ctx, val_ptrs, true)) return false; - if (!exchange_mr_vector(ctx, val_rkeys, true)) return false; + if (!exchange_mr_vector(ctx, key_ptrs, true)) return false; + if (!exchange_mr_vector(ctx, key_rkeys, true)) return false; + if (!exchange_mr_vector(ctx, val_ptrs, true)) return false; + if (!exchange_mr_vector(ctx, val_rkeys, true)) return false; - for (int i = 0; i < layer_num; ++i) { - ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]); - ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]); - ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]); - ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]); - } - return true; + for (int i = 0; i < layer_num; ++i) { + ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]); + ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]); + ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]); + ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]); + } + return true; } /** @@ -709,49 +747,49 @@ bool client_exchange_mr(struct RdmaContext *ctx) { * @return true on success, false on failure */ bool server_exchange_mr(struct RdmaContext *ctx) { - LOGD("verbs server exchange mr: start"); + LOGD("verbs server exchange mr: start"); - if (ctx->conn.layer_number <= 0) { - ERR("Invalid layer number: %d", ctx->conn.layer_number); - return false; - } + if (ctx->conn.layer_number <= 0) { + ERR("Invalid layer number: %d", ctx->conn.layer_number); + return false; + } - auto layer_num = ctx->conn.layer_number; - auto& key_mrs = ctx->conn.write_cache_key_server_mr_list; - auto& val_mrs = ctx->conn.write_cache_value_server_mr_list; + auto layer_num = ctx->conn.layer_number; + auto &key_mrs = ctx->conn.write_cache_key_server_mr_list; + auto &val_mrs = ctx->conn.write_cache_value_server_mr_list; - // Verify that server memory regions are properly initialized - if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) { - ERR("server write cache memory region size error"); - return false; - } + // Verify that server memory regions are properly initialized + if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) { + ERR("server write cache memory region size error"); + return false; + } - // Prepare memory region information to send - std::vector send_key_ptrs; - std::vector send_key_rkeys; - std::vector send_val_ptrs; - std::vector send_val_rkeys; + // Prepare memory region information to send + std::vector send_key_ptrs; + std::vector send_key_rkeys; + std::vector send_val_ptrs; + std::vector send_val_rkeys; - send_key_ptrs.reserve(layer_num); - send_key_rkeys.reserve(layer_num); - send_val_ptrs.reserve(layer_num); - send_val_rkeys.reserve(layer_num); + send_key_ptrs.reserve(layer_num); + send_key_rkeys.reserve(layer_num); + send_val_ptrs.reserve(layer_num); + send_val_rkeys.reserve(layer_num); - // Collect memory region information from local MRs - for (int i = 0; i < layer_num; ++i) { - send_key_ptrs.push_back(reinterpret_cast(key_mrs[i]->addr)); - send_key_rkeys.push_back(key_mrs[i]->rkey); - send_val_ptrs.push_back(reinterpret_cast(val_mrs[i]->addr)); - send_val_rkeys.push_back(val_mrs[i]->rkey); - } + // Collect memory region information from local MRs + for (int i = 0; i < layer_num; ++i) { + send_key_ptrs.push_back(reinterpret_cast(key_mrs[i]->addr)); + send_key_rkeys.push_back(key_mrs[i]->rkey); + send_val_ptrs.push_back(reinterpret_cast(val_mrs[i]->addr)); + send_val_rkeys.push_back(val_mrs[i]->rkey); + } - // Send all vectors to client - if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false; - if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false; - if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false; - if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false; + // Send all vectors to client + if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false; + if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false; + if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false; + if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false; - return true; + return true; } /** @@ -762,100 +800,105 @@ bool server_exchange_mr(struct RdmaContext *ctx) { * @param byte_num Size of the memory region in bytes * @return true on success, false on failure */ -bool server_send_memory_region(struct RdmaContext *ctx, void *local_mr, int byte_num) { - // Register the memory region for sending - ctx->conn.send_mr = ibv_reg_mr(ctx->pd, local_mr, byte_num, 0); - if (ctx->conn.send_mr == NULL) { - ERR("ibv_reg_mr failed"); - return false; - } +bool server_send_memory_region(struct RdmaContext *ctx, + void *local_mr, + int byte_num) { + // Register the memory region for sending + ctx->conn.send_mr = ibv_reg_mr(ctx->pd, local_mr, byte_num, 0); + if (ctx->conn.send_mr == NULL) { + ERR("ibv_reg_mr failed"); + return false; + } - // Prepare the send work request - struct ibv_send_wr wr, *bad_wr = NULL; - struct ibv_sge sge; + // Prepare the send work request + struct ibv_send_wr wr, *bad_wr = NULL; + struct ibv_sge sge; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = reinterpret_cast(&ctx->conn); - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - wr.send_flags = IBV_SEND_SIGNALED; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(&ctx->conn); + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.send_flags = IBV_SEND_SIGNALED; - // Set up scatter-gather element - sge.addr = (uintptr_t)local_mr; - sge.length = byte_num; - sge.lkey = ctx->conn.send_mr->lkey; + // Set up scatter-gather element + sge.addr = (uintptr_t)local_mr; + sge.length = byte_num; + sge.lkey = ctx->conn.send_mr->lkey; - // Post the send request - int ret = ibv_post_send(ctx->qp, &wr, &bad_wr); - if (ret) { - ERR("ibv_post_send failed"); - ibv_dereg_mr(ctx->conn.send_mr); - return false; - } - - // Wait for completion - struct ibv_wc wc; - ctx->conn.wc_count = 0; - ctx->conn.wc_target_count = 0; - - if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { - return false; - } - - // Deregister the memory region + // Post the send request + int ret = ibv_post_send(ctx->qp, &wr, &bad_wr); + if (ret) { + ERR("ibv_post_send failed"); ibv_dereg_mr(ctx->conn.send_mr); - return true; + return false; + } + + // Wait for completion + struct ibv_wc wc; + ctx->conn.wc_count = 0; + ctx->conn.wc_target_count = 0; + + if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { + return false; + } + + // Deregister the memory region + ibv_dereg_mr(ctx->conn.send_mr); + return true; } /** * Receive memory region information on the client side * * @param ctx The RDMA context - * @param remote_mr Pointer to the buffer where remote memory region info will be stored + * @param remote_mr Pointer to the buffer where remote memory region info will + * be stored * @param byte_num Size of the memory region in bytes * @return true on success, false on failure */ -bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int byte_num) { - // Register memory region for receiving data - int access_flags = IBV_ACCESS_LOCAL_WRITE; - ctx->conn.recv_mr = ibv_reg_mr(ctx->pd, remote_mr, byte_num, access_flags); - if (ctx->conn.recv_mr == NULL) { - ERR("ibv_reg_mr failed for receive region"); - return false; - } +bool client_receive_memory_region(struct RdmaContext *ctx, + void *remote_mr, + int byte_num) { + // Register memory region for receiving data + int access_flags = IBV_ACCESS_LOCAL_WRITE; + ctx->conn.recv_mr = ibv_reg_mr(ctx->pd, remote_mr, byte_num, access_flags); + if (ctx->conn.recv_mr == NULL) { + ERR("ibv_reg_mr failed for receive region"); + return false; + } - // Prepare the receive work request - struct ibv_recv_wr wr, *bad_wr = NULL; - struct ibv_sge sge; + // Prepare the receive work request + struct ibv_recv_wr wr, *bad_wr = NULL; + struct ibv_sge sge; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = reinterpret_cast(&ctx->conn); - wr.sg_list = &sge; - wr.num_sge = 1; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(&ctx->conn); + wr.sg_list = &sge; + wr.num_sge = 1; - // Set up scatter-gather element - sge.addr = (uintptr_t)remote_mr; - sge.length = byte_num; - sge.lkey = ctx->conn.recv_mr->lkey; + // Set up scatter-gather element + sge.addr = (uintptr_t)remote_mr; + sge.length = byte_num; + sge.lkey = ctx->conn.recv_mr->lkey; - // Post the receive request - int ret = ibv_post_recv(ctx->qp, &wr, &bad_wr); - if (ret) { - ibv_dereg_mr(ctx->conn.recv_mr); - return false; - } - - // Poll completion queue with timeout - ctx->conn.wc_count = 0; - ctx->conn.wc_target_count = 0; - if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { - return false; - } - - // Deregister memory region + // Post the receive request + int ret = ibv_post_recv(ctx->qp, &wr, &bad_wr); + if (ret) { ibv_dereg_mr(ctx->conn.recv_mr); - return true; + return false; + } + + // Poll completion queue with timeout + ctx->conn.wc_count = 0; + ctx->conn.wc_target_count = 0; + if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { + return false; + } + + // Deregister memory region + ibv_dereg_mr(ctx->conn.recv_mr); + return true; } /** @@ -865,183 +908,187 @@ bool client_receive_memory_region(struct RdmaContext *ctx, void *remote_mr, int * @return The socket file descriptor on success, -1 on failure */ int setup_listening_socket(int port) { - int sockfd = -1; - struct addrinfo hints = {0}; + int sockfd = -1; + struct addrinfo hints = {0}; - // Set up hints for getaddrinfo - hints.ai_flags = AI_PASSIVE; - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; + // Set up hints for getaddrinfo + hints.ai_flags = AI_PASSIVE; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; - struct addrinfo *res = nullptr; + struct addrinfo *res = nullptr; - // Convert port to string for getaddrinfo - std::ostringstream service; - service << port; + // Convert port to string for getaddrinfo + std::ostringstream service; + service << port; - // Get address info for the specified port - int n = getaddrinfo(nullptr, service.str().c_str(), &hints, &res); - if (n != 0) { - ERR("getaddrinfo failed for port %d: %s", port, gai_strerror(n)); - return -1; - } + // Get address info for the specified port + int n = getaddrinfo(nullptr, service.str().c_str(), &hints, &res); + if (n != 0) { + ERR("getaddrinfo failed for port %d: %s", port, gai_strerror(n)); + return -1; + } - // Check if a specific network interface is specified - const char *ifname = KVCacheConfig::getInstance().get_socket_interface(); - // Try each address until we successfully bind to one - for (struct addrinfo *t = res; t; t = t->ai_next) { - // Create socket - sockfd = socket(t->ai_family, t->ai_socktype, t->ai_protocol); - if (sockfd < 0) { - ERR("Socket creation failed: %s", strerror(errno)); - continue; - } - - // Bind to specific interface if requested - if (ifname) { - WARN("Binding socket to the specified interface: %s", ifname); - if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, ifname, strlen(ifname)) < 0) { - ERR("Failed to bind to interface %s - %s", ifname, strerror(errno)); - close(sockfd); - continue; - } - } - - // Enable address reuse - n = 1; - setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &n, sizeof(n)); - - // Attempt to bind to the address - if (bind(sockfd, t->ai_addr, t->ai_addrlen) == 0) { - break; // Successful bind - } else { - WARN("Bind failed: %s", strerror(errno)); - close(sockfd); - sockfd = -1; - } - } - - // Free the address list - freeaddrinfo(res); - - // Check if binding was successful + // Check if a specific network interface is specified + const char *ifname = KVCacheConfig::getInstance().get_socket_interface(); + // Try each address until we successfully bind to one + for (struct addrinfo *t = res; t; t = t->ai_next) { + // Create socket + sockfd = socket(t->ai_family, t->ai_socktype, t->ai_protocol); if (sockfd < 0) { - ERR("Couldn't bind to any address on port %d", port); - return -1; + ERR("Socket creation failed: %s", strerror(errno)); + continue; } - // Start listening for connections - if (listen(sockfd, 4096) < 0) { - ERR("Failed to listen on port %d: %s", port, strerror(errno)); + // Bind to specific interface if requested + if (ifname) { + WARN("Binding socket to the specified interface: %s", ifname); + if (setsockopt( + sockfd, SOL_SOCKET, SO_BINDTODEVICE, ifname, strlen(ifname)) < + 0) { + ERR("Failed to bind to interface %s - %s", ifname, strerror(errno)); close(sockfd); - return -1; + continue; + } } - // Set socket to non-blocking mode - int flags = fcntl(sockfd, F_GETFL, 0); - int ret = fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); - if (ret < 0) { - ERR("Failed to set non-blocking mode on event channel"); - close(sockfd); - return -1; - } + // Enable address reuse + n = 1; + setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &n, sizeof(n)); - // Enable TCP keep-alive - int enable = 1; - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)) < 0) { - ERR("Failed to enable TCP keep-alive on socket: %s", strerror(errno)); - close(sockfd); - return -1; + // Attempt to bind to the address + if (bind(sockfd, t->ai_addr, t->ai_addrlen) == 0) { + break; // Successful bind + } else { + WARN("Bind failed: %s", strerror(errno)); + close(sockfd); + sockfd = -1; } + } - return sockfd; + // Free the address list + freeaddrinfo(res); + + // Check if binding was successful + if (sockfd < 0) { + ERR("Couldn't bind to any address on port %d", port); + return -1; + } + + // Start listening for connections + if (listen(sockfd, 4096) < 0) { + ERR("Failed to listen on port %d: %s", port, strerror(errno)); + close(sockfd); + return -1; + } + + // Set socket to non-blocking mode + int flags = fcntl(sockfd, F_GETFL, 0); + int ret = fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); + if (ret < 0) { + ERR("Failed to set non-blocking mode on event channel"); + close(sockfd); + return -1; + } + + // Enable TCP keep-alive + int enable = 1; + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable)) < + 0) { + ERR("Failed to enable TCP keep-alive on socket: %s", strerror(errno)); + close(sockfd); + return -1; + } + + return sockfd; } int configure_epoll(int sockfd) { - int epollfd = epoll_create1(0); - if (epollfd == -1) { - ERR("epoll_create1"); - } + int epollfd = epoll_create1(0); + if (epollfd == -1) { + ERR("epoll_create1"); + } - // Initialize epoll for the listening socket - struct epoll_event ev; - ev.events = EPOLLIN | EPOLLOUT | EPOLLERR; - ev.data.fd = sockfd; - if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &ev) == -1) { - ERR("Failed to add listening socket to epoll"); - close(sockfd); - return -1; - } + // Initialize epoll for the listening socket + struct epoll_event ev; + ev.events = EPOLLIN | EPOLLOUT | EPOLLERR; + ev.data.fd = sockfd; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, &ev) == -1) { + ERR("Failed to add listening socket to epoll"); + close(sockfd); + return -1; + } - return epollfd; + return epollfd; } static char *get_ip_by_ifname(const char *ifname) { - int fd = 0; - struct ifreq ifr; - struct sockaddr_in *ip_addr = NULL; + int fd = 0; + struct ifreq ifr; + struct sockaddr_in *ip_addr = NULL; - fd = socket(AF_INET, SOCK_DGRAM, 0); - if (fd <= 0) { - ERR("create socket failed: %s", strerror(errno)); - return NULL; - } - ifr.ifr_addr.sa_family = AF_INET; - strncpy(ifr.ifr_name, ifname, IFNAMSIZ - 1); - if (ioctl(fd, SIOCGIFADDR, &ifr) == 0) { - ip_addr = (struct sockaddr_in *)&ifr.ifr_addr; - close(fd); - return inet_ntoa(ip_addr->sin_addr); - } else { - WARN("get ip from %s failed, error: %s", ifr.ifr_name, strerror(errno)); - close(fd); - return NULL; - } + fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd <= 0) { + ERR("create socket failed: %s", strerror(errno)); + return NULL; + } + ifr.ifr_addr.sa_family = AF_INET; + strncpy(ifr.ifr_name, ifname, IFNAMSIZ - 1); + if (ioctl(fd, SIOCGIFADDR, &ifr) == 0) { + ip_addr = (struct sockaddr_in *)&ifr.ifr_addr; + close(fd); + return inet_ntoa(ip_addr->sin_addr); + } else { + WARN("get ip from %s failed, error: %s", ifr.ifr_name, strerror(errno)); + close(fd); + return NULL; + } } std::vector get_net_ifname() { - std::vector local_ip; - char ifnames[KVCACHE_RDMA_NIC_MAX_LEN + 1] = {0}; - const char *tmp = KVCacheConfig::getInstance().get_socket_interface(); - if (tmp) { - int cp_len = - strlen(tmp) > KVCACHE_RDMA_NIC_MAX_LEN ? KVCACHE_RDMA_NIC_MAX_LEN : strlen(tmp); - memcpy(ifnames, tmp, cp_len); - ifnames[cp_len] = '\0'; - } else { - WARN("no ifnames, local_ip: %lu", local_ip.size()); - return local_ip; - } - char *delim = (char *)","; - int i = 0; - WARN("ifnames: %s", ifnames); - - std::string rdma_addr[KVCACHE_RDMA_MAX_NICS]; - char *saveptr = nullptr; - char *pch = strtok_r(ifnames, delim, &saveptr); - while (pch != NULL) { - rdma_addr[i++] = std::string(pch); - pch = strtok_r(NULL, delim, &saveptr); - } - int dev_id = 0; - while (dev_id < i) { - if (rdma_addr[dev_id].length() != 0) { - char *ip = get_ip_by_ifname(rdma_addr[dev_id].c_str()); - if (ip) { - local_ip.push_back(std::string(ip)); - } - } - dev_id++; - } + std::vector local_ip; + char ifnames[KVCACHE_RDMA_NIC_MAX_LEN + 1] = {0}; + const char *tmp = KVCacheConfig::getInstance().get_socket_interface(); + if (tmp) { + int cp_len = strlen(tmp) > KVCACHE_RDMA_NIC_MAX_LEN + ? KVCACHE_RDMA_NIC_MAX_LEN + : strlen(tmp); + memcpy(ifnames, tmp, cp_len); + ifnames[cp_len] = '\0'; + } else { + WARN("no ifnames, local_ip: %lu", local_ip.size()); return local_ip; + } + char *delim = (char *)","; + int i = 0; + WARN("ifnames: %s", ifnames); + + std::string rdma_addr[KVCACHE_RDMA_MAX_NICS]; + char *saveptr = nullptr; + char *pch = strtok_r(ifnames, delim, &saveptr); + while (pch != NULL) { + rdma_addr[i++] = std::string(pch); + pch = strtok_r(NULL, delim, &saveptr); + } + int dev_id = 0; + while (dev_id < i) { + if (rdma_addr[dev_id].length() != 0) { + char *ip = get_ip_by_ifname(rdma_addr[dev_id].c_str()); + if (ip) { + local_ip.push_back(std::string(ip)); + } + } + dev_id++; + } + return local_ip; } Connection::~Connection() { - write_cache_key_server_mr_list.clear(); - write_cache_value_server_mr_list.clear(); - write_cache_key_remote_ptr_list.clear(); - write_cache_key_remote_rkey_list.clear(); - write_cache_value_remote_ptr_list.clear(); - write_cache_value_remote_rkey_list.clear(); - LOGD("delete Connection %s", url.c_str()); + write_cache_key_server_mr_list.clear(); + write_cache_value_server_mr_list.clear(); + write_cache_key_remote_ptr_list.clear(); + write_cache_key_remote_rkey_list.clear(); + write_cache_value_remote_ptr_list.clear(); + write_cache_value_remote_rkey_list.clear(); + LOGD("delete Connection %s", url.c_str()); } diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp index 3f2d21016..4e443872a 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp @@ -18,19 +18,19 @@ */ #include "kvcache_rdma.h" #include "kvcache_connection.h" -#include "util.h" #include "log.h" +#include "util.h" +#include #include #include -#include -#include #include +#include +#include #include +#include #include #include -#include -#include /** * @brief Construct a new RDMACommunicator object @@ -45,11 +45,13 @@ * * @throws std::runtime_error If initialization fails */ -RDMACommunicator::RDMACommunicator(std::string &role, int gpu_idx, - std::string &port, +RDMACommunicator::RDMACommunicator(std::string& role, + int gpu_idx, + std::string& port, std::vector local_key_cache, std::vector local_value_cache, - int block_number, int block_bytes) + int block_number, + int block_bytes) : splitwise_role(role), gpu_idx(gpu_idx), port(port), @@ -59,378 +61,389 @@ RDMACommunicator::RDMACommunicator(std::string &role, int gpu_idx, block_size_byte(block_bytes), RDMACommunicator_status(0), rdma_event_channel_epoll_fd(-1) { + try { + WARN("Initializing RDMA communicator for role: %s", role.c_str()); - try { - WARN("Initializing RDMA communicator for role: %s", role.c_str()); + // Step 1: Initialize KV cache config + KVCacheConfig::getInstance().displayConfiguration(); - // Step 1: Initialize KV cache config - KVCacheConfig::getInstance().displayConfiguration(); - - // Step 2: Initialize KV cache structure - // Validate and set number of layers - layer_number = static_cast(local_cache_key_ptr_layer_head_.size()); - if (layer_number <= 0) { - throw std::runtime_error("Invalid layer number"); - } - - // Step 2: Setup cache vectors and pointers - resize_vectors(); - assign_pointers(); - - // Step 3:Initialize the event channel - rdma_event_channel_epoll_fd = epoll_create1(EPOLL_CLOEXEC); - if (rdma_event_channel_epoll_fd < 0) { - throw std::runtime_error("Failed to create epoll fd: " + - std::string(strerror(errno))); - } - - // Start the server thread (if in decode role) - if (splitwise_role == "decode") { - std::thread server_thread([this]() { - try { - this->init_server(); - } catch (const std::exception& e) { - ERR("Server thread failed: %s", e.what()); - } - }); - server_thread.detach(); - } - - RDMACommunicator_status = 1; - INFO("RDMA communicator initialized successfully"); - } catch (const std::exception& e) { - ERR("Initialization failed: %s", e.what()); - if (rdma_event_channel_epoll_fd >= 0) { - close(rdma_event_channel_epoll_fd); - rdma_event_channel_epoll_fd = -1; - } - throw; + // Step 2: Initialize KV cache structure + // Validate and set number of layers + layer_number = static_cast(local_cache_key_ptr_layer_head_.size()); + if (layer_number <= 0) { + throw std::runtime_error("Invalid layer number"); } + + // Step 2: Setup cache vectors and pointers + resize_vectors(); + assign_pointers(); + + // Step 3:Initialize the event channel + rdma_event_channel_epoll_fd = epoll_create1(EPOLL_CLOEXEC); + if (rdma_event_channel_epoll_fd < 0) { + throw std::runtime_error("Failed to create epoll fd: " + + std::string(strerror(errno))); + } + + // Start the server thread (if in decode role) + if (splitwise_role == "decode") { + std::thread server_thread([this]() { + try { + this->init_server(); + } catch (const std::exception& e) { + ERR("Server thread failed: %s", e.what()); + } + }); + server_thread.detach(); + } + + RDMACommunicator_status = 1; + INFO("RDMA communicator initialized successfully"); + } catch (const std::exception& e) { + ERR("Initialization failed: %s", e.what()); + if (rdma_event_channel_epoll_fd >= 0) { + close(rdma_event_channel_epoll_fd); + rdma_event_channel_epoll_fd = -1; + } + throw; + } } void RDMACommunicator::resize_vectors() { - if (layer_number <= 0) { - throw std::runtime_error("Invalid layer number"); - } + if (layer_number <= 0) { + throw std::runtime_error("Invalid layer number"); + } - local_cache_key_ptr_per_layer.resize(layer_number); - local_cache_value_ptr_per_layer.resize(layer_number); + local_cache_key_ptr_per_layer.resize(layer_number); + local_cache_value_ptr_per_layer.resize(layer_number); } void RDMACommunicator::assign_pointers() { - // Validate block configuration - if (block_number <= 0 || block_size_byte <= 0) { - throw std::runtime_error("Invalid block configuration"); + // Validate block configuration + if (block_number <= 0 || block_size_byte <= 0) { + throw std::runtime_error("Invalid block configuration"); + } + + // Assign pointers for each layer and block + for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) { + // Validate layer head pointers + if (local_cache_key_ptr_layer_head_[layer_idx] == 0 || + local_cache_value_ptr_layer_head_[layer_idx] == 0) { + throw std::runtime_error("Invalid cache pointer for layer " + + std::to_string(layer_idx)); } - // Assign pointers for each layer and block - for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) { - // Validate layer head pointers - if (local_cache_key_ptr_layer_head_[layer_idx] == 0 || - local_cache_value_ptr_layer_head_[layer_idx] == 0) { - throw std::runtime_error("Invalid cache pointer for layer " + - std::to_string(layer_idx)); - } + // Resize block vectors for current layer + local_cache_key_ptr_per_layer[layer_idx].resize(block_number); + local_cache_value_ptr_per_layer[layer_idx].resize(block_number); - // Resize block vectors for current layer - local_cache_key_ptr_per_layer[layer_idx].resize(block_number); - local_cache_value_ptr_per_layer[layer_idx].resize(block_number); + // Calculate and assign block pointers + for (int block_idx = 0; block_idx < block_number; ++block_idx) { + local_cache_key_ptr_per_layer[layer_idx][block_idx] = + reinterpret_cast(local_cache_key_ptr_layer_head_[layer_idx] + + block_idx * block_size_byte); - // Calculate and assign block pointers - for (int block_idx = 0; block_idx < block_number; ++block_idx) { - local_cache_key_ptr_per_layer[layer_idx][block_idx] = - reinterpret_cast( - local_cache_key_ptr_layer_head_[layer_idx] + - block_idx * block_size_byte); - - local_cache_value_ptr_per_layer[layer_idx][block_idx] = - reinterpret_cast( - local_cache_value_ptr_layer_head_[layer_idx] + - block_idx * block_size_byte); - } + local_cache_value_ptr_per_layer[layer_idx][block_idx] = + reinterpret_cast(local_cache_value_ptr_layer_head_[layer_idx] + + block_idx * block_size_byte); } + } } void RDMACommunicator::validate_addr() { - if (main_ip_list.empty()){ - throw std::runtime_error("main_ip_list is empty"); - } else { - if (!main_ip_list.empty()) { - LOGD("Local main NIC addresses:"); - for (const auto& nic_ip : main_ip_list) { - LOGD("- %s", nic_ip.c_str()); - } - } + if (main_ip_list.empty()) { + throw std::runtime_error("main_ip_list is empty"); + } else { + if (!main_ip_list.empty()) { + LOGD("Local main NIC addresses:"); + for (const auto& nic_ip : main_ip_list) { + LOGD("- %s", nic_ip.c_str()); + } } + } } RDMACommunicator::~RDMACommunicator() { - try { - WARN("Destroying RDMA communicator"); + try { + WARN("Destroying RDMA communicator"); - // Mark as closed/shutdown state - RDMACommunicator_status = 0; + // Mark as closed/shutdown state + RDMACommunicator_status = 0; - // Clean up all connections - { - std::lock_guard lock(mutex_); - conn_map.clear(); - } - - // Clean up memory regions - auto deregister_mrs = [](std::vector& mrs, const char* name) { - for (auto* mr : mrs) { - if (mr && ibv_dereg_mr(mr)) { - ERR("Failed to deregister %s MR: %s", name, strerror(errno)); - } - } - mrs.clear(); - }; - - deregister_mrs(write_mr_key_list, "write key"); - deregister_mrs(write_mr_value_list, "write value"); - deregister_mrs(write_cache_key_server_mr_list, "server key"); - deregister_mrs(write_cache_value_server_mr_list, "server value"); - - // Clean up protection domain - if (g_pd) { - if (ibv_dealloc_pd(g_pd)) { - ERR("Failed to deallocate protection domain: %s", strerror(errno)); - } - g_pd = nullptr; - } - - // Close event channel - if (rdma_event_channel_epoll_fd >= 0) { - close(rdma_event_channel_epoll_fd); - rdma_event_channel_epoll_fd = -1; - } - - WARN("RDMA communicator destroyed successfully"); - } catch (const std::exception& e) { - ERR("Destruction failed: %s", e.what()); + // Clean up all connections + { + std::lock_guard lock(mutex_); + conn_map.clear(); } + + // Clean up memory regions + auto deregister_mrs = [](std::vector& mrs, const char* name) { + for (auto* mr : mrs) { + if (mr && ibv_dereg_mr(mr)) { + ERR("Failed to deregister %s MR: %s", name, strerror(errno)); + } + } + mrs.clear(); + }; + + deregister_mrs(write_mr_key_list, "write key"); + deregister_mrs(write_mr_value_list, "write value"); + deregister_mrs(write_cache_key_server_mr_list, "server key"); + deregister_mrs(write_cache_value_server_mr_list, "server value"); + + // Clean up protection domain + if (g_pd) { + if (ibv_dealloc_pd(g_pd)) { + ERR("Failed to deallocate protection domain: %s", strerror(errno)); + } + g_pd = nullptr; + } + + // Close event channel + if (rdma_event_channel_epoll_fd >= 0) { + close(rdma_event_channel_epoll_fd); + rdma_event_channel_epoll_fd = -1; + } + + WARN("RDMA communicator destroyed successfully"); + } catch (const std::exception& e) { + ERR("Destruction failed: %s", e.what()); + } } int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) { - WARN("verbs server starting …"); + WARN("verbs server starting …"); - int sockfd = setup_listening_socket(sport); - if (sockfd < 0) { - ERR("Failed to set up listening socket"); - return -1; - } - - if (g_ib_all_devs.size() == 0) { - if(parse_port_ib_info() != 0) { - ERR("decode parse_port_ib_info error, please set rdma nics info"); - return -1; - } - } - - int use_event = 1; - int epollfd = configure_epoll(sockfd); - if (epollfd < 0) { - ERR("Failed to configure epoll"); - close(sockfd); - return -1; - } - - struct epoll_event ev, events[10]; - char buffer[QpInfo::size] = {0}; - std::map connectionContexts; - std::unique_ptr rem_dest(new QpInfo()); - std::unique_ptr local_dest(new QpInfo()); - struct RdmaContext* contexts[RDMA_TCP_CONNECT_SIZE] = {nullptr}; - - while (RDMACommunicator_status == 1) { - int nfds = epoll_wait(epollfd, events, 10, -1); - if (nfds < 0) { - if (errno == EINTR) continue; - ERR("epoll_wait failed: %s", strerror(errno)); - break; - } - - for (int i = 0; i < nfds; i++) { - int event_fd = events[i].data.fd; - - if (event_fd == sockfd) { - int connfd = accept(sockfd, nullptr, nullptr); - if (connfd < 0) { - if (errno == EINTR) continue; - ERR("accept() failed: %s", strerror(errno)); - continue; - } - - if (fcntl(connfd, F_SETFL, fcntl(connfd, F_GETFL, 0) | O_NONBLOCK) < 0) { - ERR("Failed to set non-blocking mode for connfd: %s", strerror(errno)); - close(connfd); - continue; - } - - ev.events = EPOLLIN | EPOLLRDHUP | EPOLLERR; - ev.data.fd = connfd; - if (epoll_ctl(epollfd, EPOLL_CTL_ADD, connfd, &ev) < 0) { - ERR("Failed to add connfd to epoll: %s", strerror(errno)); - close(connfd); - continue; - } - - size_t dev_idx = gpu_index % g_ib_all_devs.size(); - struct IbDeviceInfo *ib_dev = &g_ib_all_devs[dev_idx]; - struct RdmaContext* ctx = create_qp( - ib_dev, &g_pd); - if (!ctx) { - ERR("Failed to initialize RDMA Context"); - close_server_connection(connfd, ctx, epollfd, connectionContexts); - continue; - } - - connectionContexts[connfd] = ctx; - ctx->conn.layer_number = layer_number; - ctx->conn.block_number = block_number; - ctx->conn.block_byte_size = block_size_byte; - ctx->conn.local_cache_key_ptr_per_layer = local_cache_key_ptr_per_layer; - ctx->conn.local_cache_value_ptr_per_layer = local_cache_value_ptr_per_layer; - - std::lock_guard lock(mutex_); - if(!server_mr_register_per_layer(ctx)){ - ERR("server_mr_register_per_layer failed"); - return -1; - } - - if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) { - close_server_connection(connfd, ctx, epollfd, connectionContexts); - ERR("Couldn't get port info"); - continue; - } - - local_dest->lid = ctx->portinfo.lid; - local_dest->mtu = ctx->portinfo.active_mtu; - if (ctx->portinfo.link_layer != IBV_LINK_LAYER_ETHERNET && !local_dest->lid) { - close_server_connection(connfd, ctx, epollfd, connectionContexts); - ERR("Couldn't get local LID"); - continue; - } - - if (sgid_idx >= 0) { - if (ibv_query_gid(ctx->context, ib_dev->port, sgid_idx, &local_dest->gid)) { - close_server_connection(connfd, ctx, epollfd, connectionContexts); - ERR("Can't read sgid of index %d", sgid_idx); - continue; - } - } else { - memset(&local_dest->gid, 0, sizeof local_dest->gid); - } - - local_dest->qpn = ctx->qp->qp_num; - - if (server_exchange_qp_info(connfd, local_dest.get(), rem_dest.get()) < 0) { - close_server_connection(connfd, ctx, epollfd, connectionContexts); - ERR("Failed to exchange QP info"); - continue; - } - - if (modify_qp_to_rts(ctx, ib_dev->port, local_dest->psn, rem_dest.get(), sgid_idx) != QpStatus::kSuccess) { - close_server_connection(connfd, ctx, epollfd, connectionContexts); - ERR("Failed to connect to remote QP"); - continue; - } - - server_exchange_mr(ctx); - } else { - auto ctx_iter = connectionContexts.find(event_fd); - if (ctx_iter == connectionContexts.end()) { - LOGD("Unknown Connection fd: %d", event_fd); - continue; - } - struct RdmaContext* ctx = ctx_iter->second; - if (events[i].events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) { - LOGD("Connection closed or error detected on fd: %d", event_fd); - close_server_connection(event_fd, ctx, epollfd, connectionContexts); - continue; - } - - if (events[i].events & EPOLLIN) { - char buffer[sizeof(QpInfo)]; - ssize_t bytes_read = read(event_fd, buffer, sizeof(buffer)); - - if (bytes_read <= 0) { - LOGD("Read error or peer closed Connection on fd %d", event_fd); - close_server_connection(event_fd, ctx, epollfd, connectionContexts); - } - } - } - } + int sockfd = setup_listening_socket(sport); + if (sockfd < 0) { + ERR("Failed to set up listening socket"); + return -1; + } + + if (g_ib_all_devs.size() == 0) { + if (parse_port_ib_info() != 0) { + ERR("decode parse_port_ib_info error, please set rdma nics info"); + return -1; } + } + int use_event = 1; + int epollfd = configure_epoll(sockfd); + if (epollfd < 0) { + ERR("Failed to configure epoll"); close(sockfd); - close(epollfd); - return 0; + return -1; + } + + struct epoll_event ev, events[10]; + char buffer[QpInfo::size] = {0}; + std::map connectionContexts; + std::unique_ptr rem_dest(new QpInfo()); + std::unique_ptr local_dest(new QpInfo()); + struct RdmaContext* contexts[RDMA_TCP_CONNECT_SIZE] = {nullptr}; + + while (RDMACommunicator_status == 1) { + int nfds = epoll_wait(epollfd, events, 10, -1); + if (nfds < 0) { + if (errno == EINTR) continue; + ERR("epoll_wait failed: %s", strerror(errno)); + break; + } + + for (int i = 0; i < nfds; i++) { + int event_fd = events[i].data.fd; + + if (event_fd == sockfd) { + int connfd = accept(sockfd, nullptr, nullptr); + if (connfd < 0) { + if (errno == EINTR) continue; + ERR("accept() failed: %s", strerror(errno)); + continue; + } + + if (fcntl(connfd, F_SETFL, fcntl(connfd, F_GETFL, 0) | O_NONBLOCK) < + 0) { + ERR("Failed to set non-blocking mode for connfd: %s", + strerror(errno)); + close(connfd); + continue; + } + + ev.events = EPOLLIN | EPOLLRDHUP | EPOLLERR; + ev.data.fd = connfd; + if (epoll_ctl(epollfd, EPOLL_CTL_ADD, connfd, &ev) < 0) { + ERR("Failed to add connfd to epoll: %s", strerror(errno)); + close(connfd); + continue; + } + + size_t dev_idx = gpu_index % g_ib_all_devs.size(); + struct IbDeviceInfo* ib_dev = &g_ib_all_devs[dev_idx]; + struct RdmaContext* ctx = create_qp(ib_dev, &g_pd); + if (!ctx) { + ERR("Failed to initialize RDMA Context"); + close_server_connection(connfd, ctx, epollfd, connectionContexts); + continue; + } + + connectionContexts[connfd] = ctx; + ctx->conn.layer_number = layer_number; + ctx->conn.block_number = block_number; + ctx->conn.block_byte_size = block_size_byte; + ctx->conn.local_cache_key_ptr_per_layer = local_cache_key_ptr_per_layer; + ctx->conn.local_cache_value_ptr_per_layer = + local_cache_value_ptr_per_layer; + + std::lock_guard lock(mutex_); + if (!server_mr_register_per_layer(ctx)) { + ERR("server_mr_register_per_layer failed"); + return -1; + } + + if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) { + close_server_connection(connfd, ctx, epollfd, connectionContexts); + ERR("Couldn't get port info"); + continue; + } + + local_dest->lid = ctx->portinfo.lid; + local_dest->mtu = ctx->portinfo.active_mtu; + if (ctx->portinfo.link_layer != IBV_LINK_LAYER_ETHERNET && + !local_dest->lid) { + close_server_connection(connfd, ctx, epollfd, connectionContexts); + ERR("Couldn't get local LID"); + continue; + } + + if (sgid_idx >= 0) { + if (ibv_query_gid( + ctx->context, ib_dev->port, sgid_idx, &local_dest->gid)) { + close_server_connection(connfd, ctx, epollfd, connectionContexts); + ERR("Can't read sgid of index %d", sgid_idx); + continue; + } + } else { + memset(&local_dest->gid, 0, sizeof local_dest->gid); + } + + local_dest->qpn = ctx->qp->qp_num; + + if (server_exchange_qp_info(connfd, local_dest.get(), rem_dest.get()) < + 0) { + close_server_connection(connfd, ctx, epollfd, connectionContexts); + ERR("Failed to exchange QP info"); + continue; + } + + if (modify_qp_to_rts( + ctx, ib_dev->port, local_dest->psn, rem_dest.get(), sgid_idx) != + QpStatus::kSuccess) { + close_server_connection(connfd, ctx, epollfd, connectionContexts); + ERR("Failed to connect to remote QP"); + continue; + } + + server_exchange_mr(ctx); + } else { + auto ctx_iter = connectionContexts.find(event_fd); + if (ctx_iter == connectionContexts.end()) { + LOGD("Unknown Connection fd: %d", event_fd); + continue; + } + struct RdmaContext* ctx = ctx_iter->second; + if (events[i].events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) { + LOGD("Connection closed or error detected on fd: %d", event_fd); + close_server_connection(event_fd, ctx, epollfd, connectionContexts); + continue; + } + + if (events[i].events & EPOLLIN) { + char buffer[sizeof(QpInfo)]; + ssize_t bytes_read = read(event_fd, buffer, sizeof(buffer)); + + if (bytes_read <= 0) { + LOGD("Read error or peer closed Connection on fd %d", event_fd); + close_server_connection(event_fd, ctx, epollfd, connectionContexts); + } + } + } + } + } + + close(sockfd); + close(epollfd); + return 0; } -void RDMACommunicator::close_server_connection(int fd, struct RdmaContext* ctx, int epollfd, std::map& connectionContexts) { - if (ctx) { - if (!deregister_memory_regions(ctx)) { - WARN("Failed to clear memory regions for Connection fd %d", fd); - } - if (!clear_qp_info(ctx)) { - WARN("Failed to clear memory regions for Connection fd %d", fd); - } - delete ctx; +void RDMACommunicator::close_server_connection( + int fd, + struct RdmaContext* ctx, + int epollfd, + std::map& connectionContexts) { + if (ctx) { + if (!deregister_memory_regions(ctx)) { + WARN("Failed to clear memory regions for Connection fd %d", fd); } - connectionContexts.erase(fd); - epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, nullptr); - close(fd); - LOGD("Connection fd %d closed and cleaned up", fd); -} - -void RDMACommunicator::close_client_connection(int fd, struct RdmaContext* ctx, int epollfd) { - if (!ctx) { - LOGD("ctx is NULL, skipping cleanup for fd %d", fd); - epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, nullptr); - close(fd); - return; - } - - conn_map.erase(ctx->conn.url); - - for (size_t i = 0; i < ctx->conn.read_bufs.size(); ++i) { - if (ctx->conn.read_mrs[i]) ibv_dereg_mr(ctx->conn.read_mrs[i]); - if (ctx->conn.read_bufs[i]) free(ctx->conn.read_bufs[i]); - } - ctx->conn.read_bufs.clear(); - ctx->conn.read_mrs.clear(); - - - ctx->conn.connected = 0; if (!clear_qp_info(ctx)) { - LOGD("Failed to clear memory regions for Connection fd %d", fd); + WARN("Failed to clear memory regions for Connection fd %d", fd); } + delete ctx; + } + connectionContexts.erase(fd); + epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, nullptr); + close(fd); + LOGD("Connection fd %d closed and cleaned up", fd); +} +void RDMACommunicator::close_client_connection(int fd, + struct RdmaContext* ctx, + int epollfd) { + if (!ctx) { + LOGD("ctx is NULL, skipping cleanup for fd %d", fd); epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, nullptr); close(fd); - delete ctx; - LOGD("Connection fd %d closed and cleaned up", fd); + return; + } + + conn_map.erase(ctx->conn.url); + + for (size_t i = 0; i < ctx->conn.read_bufs.size(); ++i) { + if (ctx->conn.read_mrs[i]) ibv_dereg_mr(ctx->conn.read_mrs[i]); + if (ctx->conn.read_bufs[i]) free(ctx->conn.read_bufs[i]); + } + ctx->conn.read_bufs.clear(); + ctx->conn.read_mrs.clear(); + + ctx->conn.connected = 0; + if (!clear_qp_info(ctx)) { + LOGD("Failed to clear memory regions for Connection fd %d", fd); + } + + epoll_ctl(epollfd, EPOLL_CTL_DEL, fd, nullptr); + close(fd); + delete ctx; + LOGD("Connection fd %d closed and cleaned up", fd); } bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) { - if (ctx == nullptr) { - ERR("Context is null, cannot clear server Connection."); - return false; - } + if (ctx == nullptr) { + ERR("Context is null, cannot clear server Connection."); + return false; + } - for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) { - if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) { - if (ibv_dereg_mr(write_mr_key_list[layer_idx])) { - ERR("Failed to deregister memory region: write_mr_key_list, layer %d", layer_idx); - } - if (ibv_dereg_mr(write_mr_value_list[layer_idx])) { - ERR("Failed to deregister memory region: write_mr_value_list, layer %d", layer_idx); - } - } + for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) { + if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) { + if (ibv_dereg_mr(write_mr_key_list[layer_idx])) { + ERR("Failed to deregister memory region: write_mr_key_list, layer %d", + layer_idx); + } + if (ibv_dereg_mr(write_mr_value_list[layer_idx])) { + ERR("Failed to deregister memory region: write_mr_value_list, layer %d", + layer_idx); + } } - return true; + } + return true; } /** @@ -439,26 +452,25 @@ bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) { * @return Result code: 0 on success, negative value on failure */ int RDMACommunicator::init_server() { - WARN("Initializing RDMA server..."); - return start_server( - KVCacheConfig::getInstance().resolve_rdma_dest_port(port), - KVCacheConfig::getInstance().get_rdma_gid_index(), - gpu_idx - ); + WARN("Initializing RDMA server..."); + return start_server(KVCacheConfig::getInstance().resolve_rdma_dest_port(port), + KVCacheConfig::getInstance().get_rdma_gid_index(), + gpu_idx); } /** * Fetch the local IP address from the main IP list * - * @return The first IP address in the main IP list, or empty string if list is empty + * @return The first IP address in the main IP list, or empty string if list is + * empty */ std::string RDMACommunicator::fetch_local_ip() { - if (main_ip_list.empty()) { - ERR("Error: main_ip_list are empty."); - return nullptr; - } + if (main_ip_list.empty()) { + ERR("Error: main_ip_list are empty."); + return nullptr; + } - return main_ip_list[0]; + return main_ip_list[0]; } /** @@ -471,182 +483,191 @@ std::string RDMACommunicator::fetch_local_ip() { * @return ConnStatus::kConnected ConnStatus::kError; */ -int RDMACommunicator::connect(const std::string &dst_ip, - const std::string &dst_port) { - std::string url = dst_ip + ":" + dst_port; +int RDMACommunicator::connect(const std::string& dst_ip, + const std::string& dst_port) { + std::string url = dst_ip + ":" + dst_port; - // Initialize IB devices if not already done - if (g_ib_all_devs.size() == 0) { - if(parse_port_ib_info() != 0) { - ERR("prefill parse_port_ib_info is error, please set rdma nics info"); - return static_cast(ConnStatus::kInvalidParameters); - } + // Initialize IB devices if not already done + if (g_ib_all_devs.size() == 0) { + if (parse_port_ib_info() != 0) { + ERR("prefill parse_port_ib_info is error, please set rdma nics info"); + return static_cast(ConnStatus::kInvalidParameters); } + } - // Check if already connected - if (is_connected(dst_ip, dst_port)) { - INFO("Already connected to %s:%s", dst_ip.c_str(), dst_port.c_str()); - return static_cast(ConnStatus::kConnected); - } - - // Create Queue Pair (QP) for the connection - size_t dev_idx = gpu_idx % g_ib_all_devs.size(); - struct IbDeviceInfo *ib_dev = &g_ib_all_devs[dev_idx]; - struct RdmaContext *ctx = create_qp(ib_dev, &g_pd); - if (!ctx) { - ERR("Couldn't create QP"); - return static_cast(ConnStatus::kError); - } - - // Initialize connection data - ctx->conn.url = url; - ctx->conn.layer_number = layer_number; - ctx->conn.block_number = block_number; - ctx->conn.block_byte_size = block_size_byte; - - // Get port information for the connection - if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) { - ERR("Couldn't get port info"); - return static_cast(ConnStatus::kError); - } - // Register memory regions - if(!client_mr_register_per_layer(ctx)){ - ERR("server_mr_register_per_layer failed"); - return static_cast(ConnStatus::kError); - } - - // Exchange connection information with remote peer - if (!client_exchange_destinations(ctx, ib_dev->port, KVCacheConfig::getInstance().resolve_rdma_dest_port(dst_port), - KVCacheConfig::getInstance().get_rdma_gid_index(), dst_ip)) { - ERR("Couldn't getexchange port infodestinations"); - return static_cast(ConnStatus::kError); - } else { - std::lock_guard lock(mutex_); - ctx->conn.connected = 1; - conn_map[url] = ctx; - client_exchange_mr(ctx); - } - - // Allocate RDMA read and register read buffers - ctx->conn.read_bufs.resize(block_number, nullptr); - ctx->conn.read_mrs.resize(block_number, nullptr); - - for (size_t i = 0; i < block_number; ++i) { - // Allocate memory for read buffer - ctx->conn.read_bufs[i] = malloc(block_size_byte); - if (!ctx->conn.read_bufs[i]) { - ERR("Failed to allocate read buffer"); - return static_cast(ConnStatus::kError); - } - // Register memory region for read buffer - ctx->conn.read_mrs[i] = ibv_reg_mr(ctx->pd, ctx->conn.read_bufs[i], block_size_byte, IBV_ACCESS_LOCAL_WRITE); - if (!ctx->conn.read_mrs[i]) { - ERR("Failed to register memory for RDMA Read buffer"); - return static_cast(ConnStatus::kError); - } - } - - // Start client listener thread if not already started - if (start_client_listener == false) { - std::thread client_thread = std::thread([this]() { - this->client_listener(); - }); - if (client_thread.joinable()) { - client_thread.detach(); - std::lock_guard lock(mutex_); - } - start_client_listener = true; - } - - // Add socket to epoll for event monitoring - if (ctx->sock_fd != 0) { - struct epoll_event ev; - ev.events = EPOLLIN | EPOLLOUT | EPOLLERR; - ev.data.ptr = ctx; - int ret = epoll_ctl(rdma_event_channel_epoll_fd, EPOLL_CTL_ADD, ctx->sock_fd, &ev); - if (ret != 0) { - ERR("failed to add event channel %d", ret); - return static_cast(ConnStatus::kError); - } - } - - WARN("connect end ...."); + // Check if already connected + if (is_connected(dst_ip, dst_port)) { + INFO("Already connected to %s:%s", dst_ip.c_str(), dst_port.c_str()); return static_cast(ConnStatus::kConnected); + } + + // Create Queue Pair (QP) for the connection + size_t dev_idx = gpu_idx % g_ib_all_devs.size(); + struct IbDeviceInfo* ib_dev = &g_ib_all_devs[dev_idx]; + struct RdmaContext* ctx = create_qp(ib_dev, &g_pd); + if (!ctx) { + ERR("Couldn't create QP"); + return static_cast(ConnStatus::kError); + } + + // Initialize connection data + ctx->conn.url = url; + ctx->conn.layer_number = layer_number; + ctx->conn.block_number = block_number; + ctx->conn.block_byte_size = block_size_byte; + + // Get port information for the connection + if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) { + ERR("Couldn't get port info"); + return static_cast(ConnStatus::kError); + } + // Register memory regions + if (!client_mr_register_per_layer(ctx)) { + ERR("server_mr_register_per_layer failed"); + return static_cast(ConnStatus::kError); + } + + // Exchange connection information with remote peer + if (!client_exchange_destinations( + ctx, + ib_dev->port, + KVCacheConfig::getInstance().resolve_rdma_dest_port(dst_port), + KVCacheConfig::getInstance().get_rdma_gid_index(), + dst_ip)) { + ERR("Couldn't getexchange port infodestinations"); + return static_cast(ConnStatus::kError); + } else { + std::lock_guard lock(mutex_); + ctx->conn.connected = 1; + conn_map[url] = ctx; + client_exchange_mr(ctx); + } + + // Allocate RDMA read and register read buffers + ctx->conn.read_bufs.resize(block_number, nullptr); + ctx->conn.read_mrs.resize(block_number, nullptr); + + for (size_t i = 0; i < block_number; ++i) { + // Allocate memory for read buffer + ctx->conn.read_bufs[i] = malloc(block_size_byte); + if (!ctx->conn.read_bufs[i]) { + ERR("Failed to allocate read buffer"); + return static_cast(ConnStatus::kError); + } + // Register memory region for read buffer + ctx->conn.read_mrs[i] = ibv_reg_mr(ctx->pd, + ctx->conn.read_bufs[i], + block_size_byte, + IBV_ACCESS_LOCAL_WRITE); + if (!ctx->conn.read_mrs[i]) { + ERR("Failed to register memory for RDMA Read buffer"); + return static_cast(ConnStatus::kError); + } + } + + // Start client listener thread if not already started + if (start_client_listener == false) { + std::thread client_thread = + std::thread([this]() { this->client_listener(); }); + if (client_thread.joinable()) { + client_thread.detach(); + std::lock_guard lock(mutex_); + } + start_client_listener = true; + } + + // Add socket to epoll for event monitoring + if (ctx->sock_fd != 0) { + struct epoll_event ev; + ev.events = EPOLLIN | EPOLLOUT | EPOLLERR; + ev.data.ptr = ctx; + int ret = epoll_ctl( + rdma_event_channel_epoll_fd, EPOLL_CTL_ADD, ctx->sock_fd, &ev); + if (ret != 0) { + ERR("failed to add event channel %d", ret); + return static_cast(ConnStatus::kError); + } + } + + WARN("connect end ...."); + return static_cast(ConnStatus::kConnected); } int RDMACommunicator::client_listener() { - struct epoll_event events[10]; + struct epoll_event events[10]; - while (RDMACommunicator_status == 1) { - int nfds = epoll_wait(rdma_event_channel_epoll_fd, events, 10, -1); - if (nfds < 0) { - if (errno == EINTR) { - WARN("epoll_wait interrupted, continuing..."); - continue; - } - ERR("epoll_wait failed: %s", strerror(errno)); - return -1; - } - - for (int i = 0; i < nfds; ++i) { - RdmaContext* ctx = static_cast(events[i].data.ptr); - if (!ctx) { - ERR("Null context received in epoll event"); - continue; - } - - if (events[i].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) { - int err = 0; - socklen_t len = sizeof(err); - getsockopt(ctx->sock_fd, SOL_SOCKET, SO_ERROR, &err, &len); - if (err) ERR("Socket error: %s", strerror(err)); - - std::lock_guard lock(mutex_); - close_client_connection(ctx->sock_fd, ctx, rdma_event_channel_epoll_fd); - continue; - } - - if (events[i].events & EPOLLIN) { - char buffer[sizeof(QpInfo)]; - ssize_t bytes_read = read(ctx->sock_fd, buffer, sizeof(buffer)); - - if (bytes_read <= 0) { - if (bytes_read == 0) { - WARN("Peer closed connection on fd %d", ctx->sock_fd); - } else { - ERR("Read error on fd %d: %s", ctx->sock_fd, strerror(errno)); - } - - std::lock_guard lock(mutex_); - close_client_connection(ctx->sock_fd, ctx, rdma_event_channel_epoll_fd); - } - } - } + while (RDMACommunicator_status == 1) { + int nfds = epoll_wait(rdma_event_channel_epoll_fd, events, 10, -1); + if (nfds < 0) { + if (errno == EINTR) { + WARN("epoll_wait interrupted, continuing..."); + continue; + } + ERR("epoll_wait failed: %s", strerror(errno)); + return -1; } - return 0; + for (int i = 0; i < nfds; ++i) { + RdmaContext* ctx = static_cast(events[i].data.ptr); + if (!ctx) { + ERR("Null context received in epoll event"); + continue; + } + + if (events[i].events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) { + int err = 0; + socklen_t len = sizeof(err); + getsockopt(ctx->sock_fd, SOL_SOCKET, SO_ERROR, &err, &len); + if (err) ERR("Socket error: %s", strerror(err)); + + std::lock_guard lock(mutex_); + close_client_connection(ctx->sock_fd, ctx, rdma_event_channel_epoll_fd); + continue; + } + + if (events[i].events & EPOLLIN) { + char buffer[sizeof(QpInfo)]; + ssize_t bytes_read = read(ctx->sock_fd, buffer, sizeof(buffer)); + + if (bytes_read <= 0) { + if (bytes_read == 0) { + WARN("Peer closed connection on fd %d", ctx->sock_fd); + } else { + ERR("Read error on fd %d: %s", ctx->sock_fd, strerror(errno)); + } + + std::lock_guard lock(mutex_); + close_client_connection( + ctx->sock_fd, ctx, rdma_event_channel_epoll_fd); + } + } + } + } + + return 0; } -bool RDMACommunicator::is_connected(const std::string &dst_ip, const std::string &dst_port) { - std::string url = dst_ip + ":" + dst_port; - return conn_map.find(url) != conn_map.end(); +bool RDMACommunicator::is_connected(const std::string& dst_ip, + const std::string& dst_port) { + std::string url = dst_ip + ":" + dst_port; + return conn_map.find(url) != conn_map.end(); } void RDMACommunicator::remove_conn(const std::string& url) { - if (conn_map.find(url) != conn_map.end()) { - struct RdmaContext * ctx = conn_map[url]; - conn_map.erase(url); - free(ctx->context); - } + if (conn_map.find(url) != conn_map.end()) { + struct RdmaContext* ctx = conn_map[url]; + conn_map.erase(url); + free(ctx->context); + } } -struct RdmaContext *RDMACommunicator::get_conn(const std::string &ip, - const std::string &port) { - std::string url = ip + ":" + port; - if (conn_map.find(url) == conn_map.end()) { - return NULL; - } - return conn_map[url]; +struct RdmaContext* RDMACommunicator::get_conn(const std::string& ip, + const std::string& port) { + std::string url = ip + ":" + port; + if (conn_map.find(url) == conn_map.end()) { + return NULL; + } + return conn_map[url]; } /** @@ -659,29 +680,34 @@ struct RdmaContext *RDMACommunicator::get_conn(const std::string &ip, * @return Pointer to the registered memory region on success * @throws std::runtime_error Throws an exception if registration fails */ -struct ibv_mr* RDMACommunicator::register_memory_region( - ibv_pd* pd, void* addr, size_t size, - const std::string& desc, uint32_t access_flags) { +struct ibv_mr* RDMACommunicator::register_memory_region(ibv_pd* pd, + void* addr, + size_t size, + const std::string& desc, + uint32_t access_flags) { + if (!pd || !addr || size == 0) { + throw std::invalid_argument("Invalid memory region parameters"); + } - if (!pd || !addr || size == 0) { - throw std::invalid_argument("Invalid memory region parameters"); - } + // Check and set the Relaxed Ordering flag + if (KVCacheConfig::getInstance().is_relax_ordering_enabled()) { + access_flags |= IBV_ACCESS_RELAXED_ORDERING; + LOGD("Enabled Relaxed Ordering for %s", desc.c_str()); + } - // Check and set the Relaxed Ordering flag - if (KVCacheConfig::getInstance().is_relax_ordering_enabled()) { - access_flags |= IBV_ACCESS_RELAXED_ORDERING; - LOGD("Enabled Relaxed Ordering for %s", desc.c_str()); - } + struct ibv_mr* mr = ibv_reg_mr(pd, addr, size, access_flags); + if (!mr) { + throw std::runtime_error("Failed to register memory region " + desc + ": " + + strerror(errno)); + } - struct ibv_mr* mr = ibv_reg_mr(pd, addr, size, access_flags); - if (!mr) { - throw std::runtime_error("Failed to register memory region " + desc + - ": " + strerror(errno)); - } - - LOGD("Registered %s MR: addr=%p, size=%zu, flags=0x%x, lkey=0x%x", - desc.c_str(), addr, size, access_flags, mr->lkey); - return mr; + LOGD("Registered %s MR: addr=%p, size=%zu, flags=0x%x, lkey=0x%x", + desc.c_str(), + addr, + size, + access_flags, + mr->lkey); + return mr; } /** @@ -689,58 +715,69 @@ struct ibv_mr* RDMACommunicator::register_memory_region( * @param ctx Pointer to the RDMA context * @note This method registers memory regions for the KV cache of each layer */ -bool RDMACommunicator::client_mr_register_per_layer(RdmaContext *ctx) { - if (!ctx || !ctx->pd) { - ERR("Invalid RDMA context"); - return false; - } +bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) { + if (!ctx || !ctx->pd) { + ERR("Invalid RDMA context"); + return false; + } - std::lock_guard lock(mutex_); - - if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) { - WARN("Memory regions already registered"); - return true; - } - - const size_t list_size = layer_number; - write_mr_key_list.resize(list_size, nullptr); - write_mr_value_list.resize(list_size, nullptr); - - const uint32_t access_flags = IBV_ACCESS_LOCAL_WRITE | - (KVCacheConfig::getInstance().is_relax_ordering_enabled() ? IBV_ACCESS_RELAXED_ORDERING : 0); - - for (int i = 0; i < static_cast(list_size); ++i) { - void* key_ptr = reinterpret_cast(local_cache_key_ptr_layer_head_[i]); - void* val_ptr = reinterpret_cast(local_cache_value_ptr_layer_head_[i]); - size_t size = static_cast(block_size_byte) * block_number; - - write_mr_key_list[i] = register_memory_region(ctx->pd, key_ptr, size, - "client_key_" + std::to_string(i), access_flags); - if (!write_mr_key_list[i]) goto fail; - - write_mr_value_list[i] = register_memory_region(ctx->pd, val_ptr, size, - "client_value_" + std::to_string(i), access_flags); - if (!write_mr_value_list[i]) goto fail; - } + std::lock_guard lock(mutex_); + if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) { + WARN("Memory regions already registered"); return true; + } + + const size_t list_size = layer_number; + write_mr_key_list.resize(list_size, nullptr); + write_mr_value_list.resize(list_size, nullptr); + + const uint32_t access_flags = + IBV_ACCESS_LOCAL_WRITE | + (KVCacheConfig::getInstance().is_relax_ordering_enabled() + ? IBV_ACCESS_RELAXED_ORDERING + : 0); + + for (int i = 0; i < static_cast(list_size); ++i) { + void* key_ptr = reinterpret_cast(local_cache_key_ptr_layer_head_[i]); + void* val_ptr = + reinterpret_cast(local_cache_value_ptr_layer_head_[i]); + size_t size = static_cast(block_size_byte) * block_number; + + write_mr_key_list[i] = + register_memory_region(ctx->pd, + key_ptr, + size, + "client_key_" + std::to_string(i), + access_flags); + if (!write_mr_key_list[i]) goto fail; + + write_mr_value_list[i] = + register_memory_region(ctx->pd, + val_ptr, + size, + "client_value_" + std::to_string(i), + access_flags); + if (!write_mr_value_list[i]) goto fail; + } + + return true; fail: - ERR("Memory region registration failed. Cleaning up..."); + ERR("Memory region registration failed. Cleaning up..."); - for (auto* mr : write_mr_key_list) { - if (mr) ibv_dereg_mr(mr); - } - for (auto* mr : write_mr_value_list) { - if (mr) ibv_dereg_mr(mr); - } + for (auto* mr : write_mr_key_list) { + if (mr) ibv_dereg_mr(mr); + } + for (auto* mr : write_mr_value_list) { + if (mr) ibv_dereg_mr(mr); + } - write_mr_key_list.clear(); - write_mr_value_list.clear(); - return false; + write_mr_key_list.clear(); + write_mr_value_list.clear(); + return false; } - /** * @brief Register server-side memory regions for RDMA operations * @param ctx RDMA context containing protection domain and other resources @@ -748,309 +785,362 @@ fail: * @details This method registers memory regions for both keys and values * for each layer, enabling remote read/write access. */ -bool RDMACommunicator::server_mr_register_per_layer(RdmaContext *ctx) { - if (!ctx || !ctx->pd) { - ERR("Invalid RDMA context"); - return false; +bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) { + if (!ctx || !ctx->pd) { + ERR("Invalid RDMA context"); + return false; + } + + write_cache_key_server_mr_list.clear(); + write_cache_value_server_mr_list.clear(); + + const uint32_t access_flags = + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + + for (int i = 0; i < layer_number; ++i) { + void* key_ptr = reinterpret_cast(local_cache_key_ptr_layer_head_[i]); + void* val_ptr = + reinterpret_cast(local_cache_value_ptr_layer_head_[i]); + size_t size = static_cast(block_size_byte) * block_number; + + struct ibv_mr* key_mr = register_memory_region( + ctx->pd, key_ptr, size, "key_" + std::to_string(i), access_flags); + if (!key_mr) { + ERR("Failed to register key MR at layer %d", i); + goto fail; } - write_cache_key_server_mr_list.clear(); - write_cache_value_server_mr_list.clear(); - - const uint32_t access_flags = IBV_ACCESS_LOCAL_WRITE | - IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ; - - for (int i = 0; i < layer_number; ++i) { - void* key_ptr = reinterpret_cast(local_cache_key_ptr_layer_head_[i]); - void* val_ptr = reinterpret_cast(local_cache_value_ptr_layer_head_[i]); - size_t size = static_cast(block_size_byte) * block_number; - - struct ibv_mr* key_mr = register_memory_region(ctx->pd, key_ptr, size, "key_" + std::to_string(i), access_flags); - if (!key_mr) { - ERR("Failed to register key MR at layer %d", i); - goto fail; - } - - struct ibv_mr* value_mr = register_memory_region(ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags); - if (!value_mr) { - ERR("Failed to register value MR at layer %d", i); - ibv_dereg_mr(key_mr); - goto fail; - } - - write_cache_key_server_mr_list.push_back(key_mr); - write_cache_value_server_mr_list.push_back(value_mr); + struct ibv_mr* value_mr = register_memory_region( + ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags); + if (!value_mr) { + ERR("Failed to register value MR at layer %d", i); + ibv_dereg_mr(key_mr); + goto fail; } - ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list; - ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list; - return true; + write_cache_key_server_mr_list.push_back(key_mr); + write_cache_value_server_mr_list.push_back(value_mr); + } + + ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list; + ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list; + return true; fail: - for (auto* mr : write_cache_key_server_mr_list) { - if (mr) ibv_dereg_mr(mr); - } - for (auto* mr : write_cache_value_server_mr_list) { - if (mr) ibv_dereg_mr(mr); - } + for (auto* mr : write_cache_key_server_mr_list) { + if (mr) ibv_dereg_mr(mr); + } + for (auto* mr : write_cache_value_server_mr_list) { + if (mr) ibv_dereg_mr(mr); + } - write_cache_key_server_mr_list.clear(); - write_cache_value_server_mr_list.clear(); - return false; + write_cache_key_server_mr_list.clear(); + write_cache_value_server_mr_list.clear(); + return false; } -int RDMACommunicator::write_cache(const std::string &ip, - const std::string &port, - const std::vector& local_block_ids, - const std::vector& remote_block_ids, - int32_t layer_idx) { - // Parameter validation - if (local_block_ids.size() != remote_block_ids.size()) { - ERR("Block ID lists size mismatch: local=%zu, remote=%zu", - local_block_ids.size(), remote_block_ids.size()); - return -1; - } +int RDMACommunicator::write_cache(const std::string& ip, + const std::string& port, + const std::vector& local_block_ids, + const std::vector& remote_block_ids, + int32_t layer_idx) { + // Parameter validation + if (local_block_ids.size() != remote_block_ids.size()) { + ERR("Block ID lists size mismatch: local=%zu, remote=%zu", + local_block_ids.size(), + remote_block_ids.size()); + return -1; + } - if (layer_idx < 0 || layer_idx >= layer_number) { - ERR("Invalid layer index: %d (max: %d)", layer_idx, layer_number - 1); - return -1; - } + if (layer_idx < 0 || layer_idx >= layer_number) { + ERR("Invalid layer index: %d (max: %d)", layer_idx, layer_number - 1); + return -1; + } - const auto block_num = local_block_ids.size(); - if (block_num == 0) { - WARN("Empty block list, nothing to write"); - return 0; - } - - // Performance debugging - std::chrono::steady_clock::time_point start_time; - if (KVCacheConfig::getInstance().is_debug_mode_enabled()) { - start_time = std::chrono::steady_clock::now(); - } - - // Get connection context with thread safety - std::unique_lock lock(mutex_); - auto* ctx = get_conn(ip, port); - lock.unlock(); - - if (!ctx || !ctx->conn.connected) { - ERR("No active connection to %s:%s", ip.c_str(), port.c_str()); - return -1; - } - - std::vector cache_key_remote_addr(block_num); - std::vector cache_value_remote_addr(block_num); - std::vector crc_cache_key_remote_addr(block_num); - std::vector crc_cache_value_remote_addr(block_num); - - uint32_t cache_key_rkey = ctx->conn.write_cache_key_remote_rkey_list[layer_idx]; - uint32_t cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx]; - uint32_t crc_cache_key_rkey, crc_cache_value_rkey; - - for (size_t block_index = 0; block_index < block_num; ++block_index) { - char* char_ptr = static_cast(ctx->conn.write_cache_key_remote_ptr_list[layer_idx]); - cache_key_remote_addr[block_index] = - (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte)); - char_ptr = static_cast(ctx->conn.write_cache_value_remote_ptr_list[layer_idx]); - cache_value_remote_addr[block_index] = - (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte)); - } - ctx->conn.wc_target_count = 0; - for (int i = 0; i < 2; ++i) { - bool is_key = (i == 0); - uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey); - std::vector& remote_addr = (is_key ? cache_key_remote_addr : cache_value_remote_addr); - if (!post_block_send(ctx, layer_idx, local_block_ids, is_key, remote_addr, rkey, ip, port)) { - return -1; - } - } - - if (KVCacheConfig::getInstance().is_debug_mode_enabled()) { - auto duration_us = std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time).count(); - - DEBUG("Write cache completed - IP: %s, Port: %s, Layer: %d, BlockSize: %d, Blocks: %lu, Duration: %ld us", - ip.c_str(), port.c_str(), layer_idx, block_size_byte, block_num, duration_us); - } + const auto block_num = local_block_ids.size(); + if (block_num == 0) { + WARN("Empty block list, nothing to write"); return 0; + } + + // Performance debugging + std::chrono::steady_clock::time_point start_time; + if (KVCacheConfig::getInstance().is_debug_mode_enabled()) { + start_time = std::chrono::steady_clock::now(); + } + + // Get connection context with thread safety + std::unique_lock lock(mutex_); + auto* ctx = get_conn(ip, port); + lock.unlock(); + + if (!ctx || !ctx->conn.connected) { + ERR("No active connection to %s:%s", ip.c_str(), port.c_str()); + return -1; + } + + std::vector cache_key_remote_addr(block_num); + std::vector cache_value_remote_addr(block_num); + std::vector crc_cache_key_remote_addr(block_num); + std::vector crc_cache_value_remote_addr(block_num); + + uint32_t cache_key_rkey = + ctx->conn.write_cache_key_remote_rkey_list[layer_idx]; + uint32_t cache_value_rkey = + ctx->conn.write_cache_value_remote_rkey_list[layer_idx]; + uint32_t crc_cache_key_rkey, crc_cache_value_rkey; + + for (size_t block_index = 0; block_index < block_num; ++block_index) { + char* char_ptr = static_cast( + ctx->conn.write_cache_key_remote_ptr_list[layer_idx]); + cache_key_remote_addr[block_index] = + (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte)); + char_ptr = static_cast( + ctx->conn.write_cache_value_remote_ptr_list[layer_idx]); + cache_value_remote_addr[block_index] = + (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte)); + } + ctx->conn.wc_target_count = 0; + for (int i = 0; i < 2; ++i) { + bool is_key = (i == 0); + uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey); + std::vector& remote_addr = + (is_key ? cache_key_remote_addr : cache_value_remote_addr); + if (!post_block_send(ctx, + layer_idx, + local_block_ids, + is_key, + remote_addr, + rkey, + ip, + port)) { + return -1; + } + } + + if (KVCacheConfig::getInstance().is_debug_mode_enabled()) { + auto duration_us = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time) + .count(); + + DEBUG( + "Write cache completed - IP: %s, Port: %s, Layer: %d, BlockSize: %d, " + "Blocks: %lu, Duration: %ld us", + ip.c_str(), + port.c_str(), + layer_idx, + block_size_byte, + block_num, + duration_us); + } + return 0; } -bool RDMACommunicator::post_block_send(struct RdmaContext* ctx, int layer_idx, - const std::vector& local_block_ids, - bool is_key, std::vector& remote_addr, - uint32_t rkey, const std::string &ip, - const std::string &port) { - auto block_num = local_block_ids.size(); - assert(block_num > 0 && "block_num must be > 0"); +bool RDMACommunicator::post_block_send( + struct RdmaContext* ctx, + int layer_idx, + const std::vector& local_block_ids, + bool is_key, + std::vector& remote_addr, + uint32_t rkey, + const std::string& ip, + const std::string& port) { + auto block_num = local_block_ids.size(); + assert(block_num > 0 && "block_num must be > 0"); - bool success = execute_rdma_writes(ctx, layer_idx, local_block_ids, - is_key, remote_addr, rkey); + bool success = execute_rdma_writes( + ctx, layer_idx, local_block_ids, is_key, remote_addr, rkey); - if (success) { - if (KVCacheConfig::getInstance().is_gdrcopy_flush_enabled()) { - const size_t last_idx = block_num - 1; - success = execute_read_verification(ctx, last_idx, remote_addr[last_idx], - rkey, layer_idx, ip, port); - } + if (success) { + if (KVCacheConfig::getInstance().is_gdrcopy_flush_enabled()) { + const size_t last_idx = block_num - 1; + success = execute_read_verification( + ctx, last_idx, remote_addr[last_idx], rkey, layer_idx, ip, port); } + } - return success; + return success; } -bool RDMACommunicator::execute_rdma_writes(struct RdmaContext* ctx, int layer_idx, - const std::vector& local_block_ids, - bool is_key, std::vector& remote_addr, - uint32_t rkey) { - auto block_num = local_block_ids.size(); - struct ibv_sge* sge_list = new ibv_sge[block_num]; - struct ibv_send_wr* send_wr_list = new ibv_send_wr[block_num]; +bool RDMACommunicator::execute_rdma_writes( + struct RdmaContext* ctx, + int layer_idx, + const std::vector& local_block_ids, + bool is_key, + std::vector& remote_addr, + uint32_t rkey) { + auto block_num = local_block_ids.size(); + struct ibv_sge* sge_list = new ibv_sge[block_num]; + struct ibv_send_wr* send_wr_list = new ibv_send_wr[block_num]; - prepare_write_requests(sge_list, send_wr_list, layer_idx, - local_block_ids, is_key, remote_addr, rkey); + prepare_write_requests(sge_list, + send_wr_list, + layer_idx, + local_block_ids, + is_key, + remote_addr, + rkey); - bool success = true; - size_t inflight_wr = 0; + bool success = true; + size_t inflight_wr = 0; - for (size_t scnt = 0; scnt < block_num; ++scnt) { - size_t idx = scnt % RDMA_WR_LIST_MAX_SIZE; - inflight_wr++; + for (size_t scnt = 0; scnt < block_num; ++scnt) { + size_t idx = scnt % RDMA_WR_LIST_MAX_SIZE; + inflight_wr++; - bool is_batch_end = (idx == RDMA_WR_LIST_MAX_SIZE - 1 || scnt == block_num - 1); - bool need_poll = (inflight_wr >= RDMA_SQ_MAX_SIZE || scnt == block_num - 1); + bool is_batch_end = + (idx == RDMA_WR_LIST_MAX_SIZE - 1 || scnt == block_num - 1); + bool need_poll = (inflight_wr >= RDMA_SQ_MAX_SIZE || scnt == block_num - 1); - if (is_batch_end) { - if (!post_send_with_retry(ctx, &send_wr_list[scnt - idx], - inflight_wr, need_poll)) { - success = false; - break; - } - if (need_poll) { - inflight_wr = 0; - } - } + if (is_batch_end) { + if (!post_send_with_retry( + ctx, &send_wr_list[scnt - idx], inflight_wr, need_poll)) { + success = false; + break; + } + if (need_poll) { + inflight_wr = 0; + } } + } - delete[] sge_list; - delete[] send_wr_list; - return success; + delete[] sge_list; + delete[] send_wr_list; + return success; } -void RDMACommunicator::prepare_write_requests(struct ibv_sge* sge_list, - struct ibv_send_wr* send_wr_list, - int layer_idx, - const std::vector& local_block_ids, - bool is_key, - std::vector& remote_addr, - uint32_t rkey) { - auto block_num = local_block_ids.size(); +void RDMACommunicator::prepare_write_requests( + struct ibv_sge* sge_list, + struct ibv_send_wr* send_wr_list, + int layer_idx, + const std::vector& local_block_ids, + bool is_key, + std::vector& remote_addr, + uint32_t rkey) { + auto block_num = local_block_ids.size(); - for (size_t i = 0; i < block_num; ++i) { - sge_list[i].addr = (uintptr_t)(is_key ? - local_cache_key_ptr_per_layer[layer_idx][local_block_ids[i]] : - local_cache_value_ptr_per_layer[layer_idx][local_block_ids[i]]); - sge_list[i].length = block_size_byte; - sge_list[i].lkey = (is_key ? - write_mr_key_list[layer_idx]->lkey : - write_mr_value_list[layer_idx]->lkey); + for (size_t i = 0; i < block_num; ++i) { + sge_list[i].addr = + (uintptr_t)(is_key + ? local_cache_key_ptr_per_layer[layer_idx] + [local_block_ids[i]] + : local_cache_value_ptr_per_layer[layer_idx] + [local_block_ids[i]]); + sge_list[i].length = block_size_byte; + sge_list[i].lkey = (is_key ? write_mr_key_list[layer_idx]->lkey + : write_mr_value_list[layer_idx]->lkey); - size_t idx = i % RDMA_WR_LIST_MAX_SIZE; - send_wr_list[i].wr_id = i; - send_wr_list[i].next = (idx == RDMA_WR_LIST_MAX_SIZE - 1 || i == block_num - 1) ? - nullptr : &send_wr_list[i + 1]; - send_wr_list[i].sg_list = &sge_list[i]; - send_wr_list[i].num_sge = 1; - send_wr_list[i].opcode = IBV_WR_RDMA_WRITE; - send_wr_list[i].send_flags = (i == block_num - 1) ? IBV_SEND_SIGNALED : 0; - send_wr_list[i].wr.rdma.remote_addr = remote_addr[i]; - send_wr_list[i].wr.rdma.rkey = rkey; - } + size_t idx = i % RDMA_WR_LIST_MAX_SIZE; + send_wr_list[i].wr_id = i; + send_wr_list[i].next = + (idx == RDMA_WR_LIST_MAX_SIZE - 1 || i == block_num - 1) + ? nullptr + : &send_wr_list[i + 1]; + send_wr_list[i].sg_list = &sge_list[i]; + send_wr_list[i].num_sge = 1; + send_wr_list[i].opcode = IBV_WR_RDMA_WRITE; + send_wr_list[i].send_flags = (i == block_num - 1) ? IBV_SEND_SIGNALED : 0; + send_wr_list[i].wr.rdma.remote_addr = remote_addr[i]; + send_wr_list[i].wr.rdma.rkey = rkey; + } } bool RDMACommunicator::post_send_with_retry(struct RdmaContext* ctx, - struct ibv_send_wr* wr_list, - size_t inflight_wr, - bool need_poll) { - const int max_retries = 7; - int retries = 0; - int ret = 0; - struct ibv_send_wr* bad_wr = nullptr; + struct ibv_send_wr* wr_list, + size_t inflight_wr, + bool need_poll) { + const int max_retries = 7; + int retries = 0; + int ret = 0; + struct ibv_send_wr* bad_wr = nullptr; - if (inflight_wr >= RDMA_SQ_MAX_SIZE && wr_list) { - struct ibv_send_wr* last_wr = wr_list; - while (last_wr->next) { - last_wr = last_wr->next; - } - last_wr->send_flags |= IBV_SEND_SIGNALED; + if (inflight_wr >= RDMA_SQ_MAX_SIZE && wr_list) { + struct ibv_send_wr* last_wr = wr_list; + while (last_wr->next) { + last_wr = last_wr->next; } + last_wr->send_flags |= IBV_SEND_SIGNALED; + } - do { - ret = ibv_post_send(ctx->qp, wr_list, &bad_wr); - if (ret == 0) { - if (need_poll) { - ctx->conn.wc_count = 0; - ctx->conn.wc_target_count = 0; - if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { - ERR("Polling CQ failed after RDMA Write"); - return false; - } - } - return true; - } else { - ERR("ibv_post_send failed: %s (errno: %d), retry %d/%d", - strerror(errno), errno, retries + 1, max_retries); - usleep(1000); - retries++; + do { + ret = ibv_post_send(ctx->qp, wr_list, &bad_wr); + if (ret == 0) { + if (need_poll) { + ctx->conn.wc_count = 0; + ctx->conn.wc_target_count = 0; + if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { + ERR("Polling CQ failed after RDMA Write"); + return false; } - } while (retries < max_retries); + } + return true; + } else { + ERR("ibv_post_send failed: %s (errno: %d), retry %d/%d", + strerror(errno), + errno, + retries + 1, + max_retries); + usleep(1000); + retries++; + } + } while (retries < max_retries); - ERR("ibv_post_send failed after %d retries: %s (errno: %d)", - retries, strerror(errno), errno); - return false; + ERR("ibv_post_send failed after %d retries: %s (errno: %d)", + retries, + strerror(errno), + errno); + return false; } bool RDMACommunicator::execute_read_verification(struct RdmaContext* ctx, - size_t block_idx, - uint64_t remote_addr, - uint32_t rkey, - int layer_idx, - const std::string& ip, - const std::string& port) { - ibv_sge read_sge = { - .addr = reinterpret_cast(ctx->conn.read_bufs[block_idx]), - .length = static_cast(block_size_byte), - .lkey = ctx->conn.read_mrs[block_idx]->lkey - }; + size_t block_idx, + uint64_t remote_addr, + uint32_t rkey, + int layer_idx, + const std::string& ip, + const std::string& port) { + ibv_sge read_sge = { + .addr = reinterpret_cast(ctx->conn.read_bufs[block_idx]), + .length = static_cast(block_size_byte), + .lkey = ctx->conn.read_mrs[block_idx]->lkey}; - ibv_send_wr read_wr = {}; - read_wr.wr_id = 1000 + block_idx; - read_wr.sg_list = &read_sge; - read_wr.num_sge = 1; - read_wr.opcode = IBV_WR_RDMA_READ; - read_wr.send_flags = IBV_SEND_SIGNALED; - read_wr.wr.rdma.remote_addr = remote_addr; - read_wr.wr.rdma.rkey = rkey; + ibv_send_wr read_wr = {}; + read_wr.wr_id = 1000 + block_idx; + read_wr.sg_list = &read_sge; + read_wr.num_sge = 1; + read_wr.opcode = IBV_WR_RDMA_READ; + read_wr.send_flags = IBV_SEND_SIGNALED; + read_wr.wr.rdma.remote_addr = remote_addr; + read_wr.wr.rdma.rkey = rkey; - ibv_send_wr* bad_wr = nullptr; - int ret = ibv_post_send(ctx->qp, &read_wr, &bad_wr); - if (ret != 0) { - ERR("RDMA Read verification failed: %s (errno: %d)", strerror(errno), errno); - return false; - } + ibv_send_wr* bad_wr = nullptr; + int ret = ibv_post_send(ctx->qp, &read_wr, &bad_wr); + if (ret != 0) { + ERR("RDMA Read verification failed: %s (errno: %d)", + strerror(errno), + errno); + return false; + } - if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { - ERR("RDMA Read verification polling failed"); - return false; - } + if (!poll_cq_with_timeout(ctx, RDMA_POLL_CQE_TIMEOUT, 1)) { + ERR("RDMA Read verification polling failed"); + return false; + } - if (KVCacheConfig::getInstance().is_debug_output_enabled()) { - uint8_t* data = reinterpret_cast(ctx->conn.read_bufs[block_idx]); - uint8_t first_byte = data[0]; - uint8_t last_byte = data[block_size_byte - 1]; - DEBUG("Read verification success - Block %zu (Layer: %d, %s:%s): first=%u, last=%u", - block_idx, layer_idx, ip.c_str(), port.c_str(), - static_cast(first_byte), static_cast(last_byte)); - } + if (KVCacheConfig::getInstance().is_debug_output_enabled()) { + uint8_t* data = reinterpret_cast(ctx->conn.read_bufs[block_idx]); + uint8_t first_byte = data[0]; + uint8_t last_byte = data[block_size_byte - 1]; + DEBUG( + "Read verification success - Block %zu (Layer: %d, %s:%s): first=%u, " + "last=%u", + block_idx, + layer_idx, + ip.c_str(), + port.c_str(), + static_cast(first_byte), + static_cast(last_byte)); + } - return true; + return true; } diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/log.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/log.cpp index 603ff6595..6aa69a832 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/log.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/log.cpp @@ -17,14 +17,14 @@ * limitations under the License. */ -#include -#include -#include -#include -#include -#include -#include #include "log.h" +#include +#include +#include +#include +#include +#include +#include #include "util.h" static int pid = -1; @@ -33,180 +33,237 @@ static char hostname[64]; char global_log_last_error[1024] = ""; FILE *global_debug_file = stdout; FILE *global_error_file = stdout; -static char global_debug_file_name[PATH_MAX+1] = ""; -static char global_err_file_name[PATH_MAX+1] = ""; +static char global_debug_file_name[PATH_MAX + 1] = ""; +static char global_err_file_name[PATH_MAX + 1] = ""; int global_debug_level = -1; pthread_mutex_t global_debug_lock = PTHREAD_MUTEX_INITIALIZER; pthread_mutex_t global_log_file_lock = PTHREAD_MUTEX_INITIALIZER; -void log_file_init(FILE **kv_cache_log_file, const char *kv_cache_log_file_env, char *logFileName) { - int c = 0; - char *dfn = logFileName; - while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') { - if (kv_cache_log_file_env[c++] != '%') { - *dfn++ = kv_cache_log_file_env[c - 1]; - continue; - } - switch (kv_cache_log_file_env[c++]) { - case '%': // Double % - *dfn++ = '%'; - break; - case 'h': // %h = hostname - dfn += snprintf(dfn, PATH_MAX, "%s", hostname); - break; - case 'p': // %p = pid - dfn += snprintf(dfn, PATH_MAX, "%d", pid); - break; - default: // Echo everything we don't understand - *dfn++ = '%'; - *dfn++ = kv_cache_log_file_env[c - 1]; - break; - } +void log_file_init(FILE **kv_cache_log_file, + const char *kv_cache_log_file_env, + char *logFileName) { + int c = 0; + char *dfn = logFileName; + while (c < PATH_MAX && kv_cache_log_file_env[c] != '\0') { + if (kv_cache_log_file_env[c++] != '%') { + *dfn++ = kv_cache_log_file_env[c - 1]; + continue; } - *dfn = '\0'; - if (logFileName[0] != '\0') { - FILE *file = fopen(logFileName, "w"); - if (file != nullptr) { - setbuf(file, nullptr); // disable buffering - *kv_cache_log_file = file; - } + switch (kv_cache_log_file_env[c++]) { + case '%': // Double % + *dfn++ = '%'; + break; + case 'h': // %h = hostname + dfn += snprintf(dfn, PATH_MAX, "%s", hostname); + break; + case 'p': // %p = pid + dfn += snprintf(dfn, PATH_MAX, "%d", pid); + break; + default: // Echo everything we don't understand + *dfn++ = '%'; + *dfn++ = kv_cache_log_file_env[c - 1]; + break; } + } + *dfn = '\0'; + if (logFileName[0] != '\0') { + FILE *file = fopen(logFileName, "w"); + if (file != nullptr) { + setbuf(file, nullptr); // disable buffering + *kv_cache_log_file = file; + } + } } void recreate_log_file(FILE **kv_cache_log_file, char *logFileName) { - if (logFileName[0] != '\0') { - pthread_mutex_lock(&global_log_file_lock); - FILE *file = fopen(logFileName, "a"); // Use "a" mode to append if file exists, otherwise create it - // close the previous log file if it exists - if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) { - fclose(*kv_cache_log_file); - *kv_cache_log_file = NULL; - } - if (file != NULL) { - setbuf(file, NULL); // disable buffering - *kv_cache_log_file = file; - } - pthread_mutex_unlock(&global_log_file_lock); + if (logFileName[0] != '\0') { + pthread_mutex_lock(&global_log_file_lock); + FILE *file = fopen( + logFileName, + "a"); // Use "a" mode to append if file exists, otherwise create it + // close the previous log file if it exists + if (*kv_cache_log_file != NULL && *kv_cache_log_file != file) { + fclose(*kv_cache_log_file); + *kv_cache_log_file = NULL; } + if (file != NULL) { + setbuf(file, NULL); // disable buffering + *kv_cache_log_file = file; + } + pthread_mutex_unlock(&global_log_file_lock); + } } void debug_init() { - pthread_mutex_lock(&global_debug_lock); - if (global_debug_level != -1) { - pthread_mutex_unlock(&global_debug_lock); - return; - } - - const char* kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED"); - int tempg_kv_cache_debug_level = -1; - - if (kv_cache_debug == NULL) { - tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO; - } else if (strcasecmp(kv_cache_debug, "0") == 0) { - tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO; - } else if (strcasecmp(kv_cache_debug, "1") == 0) { - tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG; - } else if (strcasecmp(kv_cache_debug, "2") == 0) { - tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN; - } else if (strcasecmp(kv_cache_debug, "3") == 0) { - tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR; - } else { - tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO; - } - - gethostname(hostname, 64); - pid = getpid(); - - const char* g_kv_cache_debug_fileEnv = KVCacheConfig::getInstance().get_debug_file_path(); - if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_debug_fileEnv != NULL) { - log_file_init(&global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name); - } - - const char* g_kv_cache_error_fileEnv = KVCacheConfig::getInstance().get_error_file_path(); - if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && g_kv_cache_error_fileEnv != NULL) { - log_file_init(&global_error_file, g_kv_cache_error_fileEnv, global_err_file_name); - char buffer[1024]; - size_t len = 0; - char timeBuffer[80]; // Buffer to hold the formatted time - std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); - std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime)); - len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer); - buffer[len++] = '\n'; - if (global_error_file != NULL) { - fwrite(buffer, 1, len, global_error_file); - } - } - __atomic_store_n(&global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE); + pthread_mutex_lock(&global_debug_lock); + if (global_debug_level != -1) { pthread_mutex_unlock(&global_debug_lock); + return; + } + + const char *kv_cache_debug = std::getenv("KV_IS_DEBUG_ENABLED"); + int tempg_kv_cache_debug_level = -1; + + if (kv_cache_debug == NULL) { + tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO; + } else if (strcasecmp(kv_cache_debug, "0") == 0) { + tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO; + } else if (strcasecmp(kv_cache_debug, "1") == 0) { + tempg_kv_cache_debug_level = KV_LOG_LEVEL_DEBUG; + } else if (strcasecmp(kv_cache_debug, "2") == 0) { + tempg_kv_cache_debug_level = KV_LOG_LEVEL_WARN; + } else if (strcasecmp(kv_cache_debug, "3") == 0) { + tempg_kv_cache_debug_level = KV_LOG_LEVEL_ERROR; + } else { + tempg_kv_cache_debug_level = KV_LOG_LEVEL_INFO; + } + + gethostname(hostname, 64); + pid = getpid(); + + const char *g_kv_cache_debug_fileEnv = + KVCacheConfig::getInstance().get_debug_file_path(); + if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && + g_kv_cache_debug_fileEnv != NULL) { + log_file_init( + &global_debug_file, g_kv_cache_debug_fileEnv, global_debug_file_name); + } + + const char *g_kv_cache_error_fileEnv = + KVCacheConfig::getInstance().get_error_file_path(); + if (tempg_kv_cache_debug_level >= KV_LOG_LEVEL_INFO && + g_kv_cache_error_fileEnv != NULL) { + log_file_init( + &global_error_file, g_kv_cache_error_fileEnv, global_err_file_name); + char buffer[1024]; + size_t len = 0; + char timeBuffer[80]; // Buffer to hold the formatted time + std::time_t absoluteTime = + std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + std::strftime(timeBuffer, + sizeof(timeBuffer), + "%Y-%m-%d %H:%M:%S", + std::localtime(&absoluteTime)); + len = snprintf(buffer, sizeof(buffer), "%s KV_CACHE START ", timeBuffer); + buffer[len++] = '\n'; + if (global_error_file != NULL) { + fwrite(buffer, 1, len, global_error_file); + } + } + __atomic_store_n( + &global_debug_level, tempg_kv_cache_debug_level, __ATOMIC_RELEASE); + pthread_mutex_unlock(&global_debug_lock); } /* Common logging function used by the INFO, DEBUG and WARN macros * Also exported to the dynamically loadable Net transport modules so * they can share the debugging mechanisms and output files */ -void debug_log(KVLogLevel level, bool enable_to_terminal, const char *filefunc, int line, const char *fmt, ...) { - if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) { - debug_init(); - } +void debug_log(KVLogLevel level, + bool enable_to_terminal, + const char *filefunc, + int line, + const char *fmt, + ...) { + if (__atomic_load_n(&global_debug_level, __ATOMIC_ACQUIRE) == -1) { + debug_init(); + } - // Save the last error (WARN) as a human readable string - if (level == KV_LOG_LEVEL_WARN) { - pthread_mutex_lock(&global_debug_lock); - va_list vargs; - va_start(vargs, fmt); - (void) vsnprintf(global_log_last_error, sizeof(global_log_last_error), fmt, vargs); - va_end(vargs); - pthread_mutex_unlock(&global_debug_lock); - } + // Save the last error (WARN) as a human readable string + if (level == KV_LOG_LEVEL_WARN) { + pthread_mutex_lock(&global_debug_lock); + va_list vargs; + va_start(vargs, fmt); + (void)vsnprintf( + global_log_last_error, sizeof(global_log_last_error), fmt, vargs); + va_end(vargs); + pthread_mutex_unlock(&global_debug_lock); + } - if (tid == -1) { - tid = syscall(SYS_gettid); - } + if (tid == -1) { + tid = syscall(SYS_gettid); + } - char buffer[1024]; - size_t len = 0; - // Convert timestamp to absolute time and directly use it in the snprintf function - std::time_t absoluteTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); - char timeBuffer[80]; // Buffer to hold the formatted time - std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", std::localtime(&absoluteTime)); + char buffer[1024]; + size_t len = 0; + // Convert timestamp to absolute time and directly use it in the snprintf + // function + std::time_t absoluteTime = + std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + char timeBuffer[80]; // Buffer to hold the formatted time + std::strftime(timeBuffer, + sizeof(timeBuffer), + "%Y-%m-%d %H:%M:%S", + std::localtime(&absoluteTime)); - if (level == KV_LOG_LEVEL_WARN) { - len = snprintf(buffer, sizeof(buffer), "\n%s %s:%d:%d %s:%d KV_CACHE WARN ", - timeBuffer, hostname, pid, tid, filefunc, line); - } else if (level == KV_LOG_LEVEL_INFO) { - len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE INFO ", timeBuffer, hostname, pid, tid); - } else if (level == KV_LOG_LEVEL_DEBUG) { - len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE DEBUG ", timeBuffer, hostname, pid, tid); - } else if (level == KV_LOG_LEVEL_ERROR) { - len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ERROR ", timeBuffer, hostname, pid, tid); - } else { - len = snprintf(buffer, sizeof(buffer), "%s %s:%d:%d KV_CACHE ", timeBuffer, hostname, pid, tid); - } + if (level == KV_LOG_LEVEL_WARN) { + len = snprintf(buffer, + sizeof(buffer), + "\n%s %s:%d:%d %s:%d KV_CACHE WARN ", + timeBuffer, + hostname, + pid, + tid, + filefunc, + line); + } else if (level == KV_LOG_LEVEL_INFO) { + len = snprintf(buffer, + sizeof(buffer), + "%s %s:%d:%d KV_CACHE INFO ", + timeBuffer, + hostname, + pid, + tid); + } else if (level == KV_LOG_LEVEL_DEBUG) { + len = snprintf(buffer, + sizeof(buffer), + "%s %s:%d:%d KV_CACHE DEBUG ", + timeBuffer, + hostname, + pid, + tid); + } else if (level == KV_LOG_LEVEL_ERROR) { + len = snprintf(buffer, + sizeof(buffer), + "%s %s:%d:%d KV_CACHE ERROR ", + timeBuffer, + hostname, + pid, + tid); + } else { + len = snprintf(buffer, + sizeof(buffer), + "%s %s:%d:%d KV_CACHE ", + timeBuffer, + hostname, + pid, + tid); + } - if (len) { - va_list vargs; - va_start(vargs, fmt); - len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs); - va_end(vargs); - // vsnprintf may return len > sizeof(buffer) in the case of a truncated output. - // Rewind len so that we can replace the final \0 by \n - if (len > sizeof(buffer)) { - len = sizeof(buffer) - 1; - } - buffer[len++] = '\n'; - if (access(global_debug_file_name, F_OK) != 0) { - recreate_log_file(&global_debug_file, global_debug_file_name); - } - if (enable_to_terminal) { - fwrite(buffer, 1, len, global_debug_file); - } - if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) { - if (access(global_err_file_name, F_OK) != 0) { - recreate_log_file(&global_error_file, global_err_file_name); - } - if (global_error_file != NULL) { - fwrite(buffer, 1, len, global_error_file); - } - } + if (len) { + va_list vargs; + va_start(vargs, fmt); + len += vsnprintf(buffer + len, sizeof(buffer) - len, fmt, vargs); + va_end(vargs); + // vsnprintf may return len > sizeof(buffer) in the case of a truncated + // output. Rewind len so that we can replace the final \0 by \n + if (len > sizeof(buffer)) { + len = sizeof(buffer) - 1; } + buffer[len++] = '\n'; + if (access(global_debug_file_name, F_OK) != 0) { + recreate_log_file(&global_debug_file, global_debug_file_name); + } + if (enable_to_terminal) { + fwrite(buffer, 1, len, global_debug_file); + } + if (level == KV_LOG_LEVEL_WARN && global_error_file != stdout) { + if (access(global_err_file_name, F_OK) != 0) { + recreate_log_file(&global_error_file, global_err_file_name); + } + if (global_error_file != NULL) { + fwrite(buffer, 1, len, global_error_file); + } + } + } } diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp index aa114ddc2..9ffcb35b2 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp @@ -6,17 +6,22 @@ namespace py = pybind11; PYBIND11_MODULE(rdma_comm, m) { - m.doc() = R"pbdoc(kv cache messager)pbdoc"; - py::class_(m, "RDMACommunicator") - .def(py::init, - std::vector, int, int>()) - .def("connect", &RDMACommunicator::connect) - .def("is_connected", &RDMACommunicator::is_connected) - .def("write_cache", &RDMACommunicator::write_cache); + m.doc() = R"pbdoc(kv cache messager)pbdoc"; + py::class_(m, "RDMACommunicator") + .def(py::init, + std::vector, + int, + int>()) + .def("connect", &RDMACommunicator::connect) + .def("is_connected", &RDMACommunicator::is_connected) + .def("write_cache", &RDMACommunicator::write_cache); #ifdef VERSION_INFO - m.attr("__version__") = VERSION_INFO; + m.attr("__version__") = VERSION_INFO; #else - m.attr("__version__") = "dev"; + m.attr("__version__") = "dev"; #endif } diff --git a/tools/codestyle/pre_commit.sh b/tools/codestyle/pre_commit.sh index bab8bcbba..7060b6215 100644 --- a/tools/codestyle/pre_commit.sh +++ b/tools/codestyle/pre_commit.sh @@ -28,6 +28,11 @@ if ! [[ $(python -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1$2}') -ge 36 please change the default python to higher version." exit 1 fi +if ! [[ $version == *"$VERSION"* ]]; then + # low version of pip may not have the source of clang-format whl + pip install --upgrade pip + pip install clang-format==13.0.0 +fi # Exclude any files under the 'test/ce/server/' directory from code style checks. diff_files=$(git diff --name-only --diff-filter=ACMR ${BRANCH} | grep -v '^tests/ce/server/')