diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index 04eb0c568..3bc93088a 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel( const paddle::Tensor &encoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_map_cpu, const paddle::Tensor &decoder_context_len_cpu, - const paddle::Tensor &decoder_batch_map_cpu) { + const paddle::Tensor &decoder_batch_map_cpu, + const std::string &pos_emb_type="NORMAL", + bool rope_3d=false) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); @@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel( int enc_batch = enc_batch_tensor.data<int>()[0]; int dec_batch = dec_batch_tensor.data<int>()[0]; int total_enc_len = total_enc_len_tensor.data<int>()[0]; + int rope_max_seqlen = 0; + int rope_3d_num_seqs = 1; + if (rope_3d) { + rope_max_seqlen = rotary_embs.dims()[3]; + rope_3d_num_seqs = rotary_embs.dims()[0]; + } else { + rope_max_seqlen = rotary_embs.dims()[2]; + } auto block_attn_out = paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place()); @@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel( prefix_lens_vp, // start_tokens param.batch_size, // batch_size 1, // emb_batch_size - rotary_embs.dims()[2], // max_seqlen + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, param.max_batch_size, block_size, max_block_per_seq, "BLHD", - "HLD", "NORMAL", + "HLD", pos_emb_type, !p_kcache_perhead_scale.defined() ? nullptr : p_kcache_perhead_scale.data() + @@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel( vsl.slot_mapping_vp, // real_batch param.batch_size, // batch_size 1, // emb_batch_size - rotary_embs.dims()[2], // max_seqlen TODO!!double check + rope_max_seqlen, // max_seqlen param.head_num, param.kv_head_num, param.head_dim, param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD", - "NORMAL", + pos_emb_type, !p_kcache_perhead_scale.defined() ? nullptr : p_kcache_perhead_scale.data() + @@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel( param.kv_head_num, // v_cache_scale_inv nullptr, // k_cache_zp nullptr, // v_cache_zp - false); // b_c8_pc + false, // b_c8_pc + rope_3d, // rope_3d + rope_3d_num_seqs); XFTBLOCK_CHECK_EQ(ret, api::SUCCESS); // attn decode @@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn) "decoder_context_len_cpu", "decoder_batch_map_cpu", }) + .Attrs({"pos_emb_type:std::string", "rope_3d:bool"}) .Outputs({"block_attn_out"}) .SetKernelFn(PD_KERNEL(BlockAttnKernel)) .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape)) diff --git a/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc b/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc new file mode 100644 index 000000000..30ca6d269 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/get_img_boundaries.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
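+// +// get_img_boundaries: scan task_input_ids once and record, after every run of text or +// image tokens, the current token offset together with the number of images seen so far. +// Each image occupies grid_h * grid_w / 4 consecutive image_patch_id tokens (grid_thw rows +// are [t, h, w]). The result is a [2, num_boundaries] INT64 CPU tensor: row 0 holds the +// token boundaries, row 1 the cumulative image counts.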
+ +#include "paddle/extension.h" + +std::vector GetImgBoundaries(const paddle::Tensor& task_input_ids, + const paddle::Tensor& grid_thw, + const int64_t image_patch_id) { + // All tensor in cpu + auto input_ids_ptr = task_input_ids.data(); + int64_t seq_lens_origin = task_input_ids.numel(); + auto grid_thw_ptr = grid_thw.data(); + + int token_times = 4; + int token_idx = 0; + int image_idx = 0; + std::vector img_boundaries, img_nums; + img_boundaries.emplace_back(0); + img_nums.emplace_back(0); + while (token_idx < seq_lens_origin) { + if (input_ids_ptr[token_idx] != image_patch_id) { + do { + token_idx++; + } while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id); + } else { + int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times; + image_idx++; + token_idx += cur_image_token_len; + } + img_boundaries.emplace_back(token_idx); + img_nums.emplace_back(image_idx); + } + + int64_t num_img_boundaries = static_cast(img_boundaries.size()); + auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace()); + + for (int i = 0; i < num_img_boundaries; i++) { + out.data()[i] = img_boundaries[i]; + out.data()[num_img_boundaries + i] = img_nums[i]; + } + + return {out}; +} + +PD_BUILD_OP(get_img_boundaries) + .Inputs({"task_input_ids", "grid_thw"}) + .Attrs({"image_patch_id: int64_t"}) + .Outputs({"img_boundaries"}) + .SetKernelFn(PD_KERNEL(GetImgBoundaries)); diff --git a/custom_ops/xpu_ops/src/ops/moe_layer.cc b/custom_ops/xpu_ops/src/ops/moe_layer.cc index c924a1735..4e8d54cb7 100644 --- a/custom_ops/xpu_ops/src/ops/moe_layer.cc +++ b/custom_ops/xpu_ops/src/ops/moe_layer.cc @@ -145,7 +145,8 @@ std::vector MoeLayerKernel( ? up_gate_proj_weight_scale.get_ptr()->data() : nullptr), xftblock_tw, - std::vector{expert_num, inter_dim, hidden_dim}); + std::vector{expert_num, inter_dim, hidden_dim} + ); xdown_proj_w = std::make_shared( const_cast(down_proj_weight.data()), nullptr, @@ -153,7 +154,8 @@ std::vector MoeLayerKernel( ? down_proj_weight_scale.get_ptr()->data() : nullptr), xftblock_tw, - std::vector{expert_num, hidden_dim, outer_dim}); + std::vector{expert_num, hidden_dim, outer_dim} + ); } std::shared_ptr xup_gate_proj_bias; std::shared_ptr xdown_proj_bias; diff --git a/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc new file mode 100644 index 000000000..a702a465f --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +void TextImageGatherScatter( + paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx); + + const int64_t token_num = input.dims()[0]; + const int64_t hidden_size = input.dims()[1]; + const int64_t text_token_num = text_input.dims()[0]; + const int64_t image_token_num = image_input.dims()[0]; + + switch (input.type()) { + case paddle::DataType::BFLOAT16: { + using XPUType = typename XPUTypeTrait<paddle::bfloat16>::Type; + typedef paddle::bfloat16 data_t; + int r = baidu::xpu::api::plugin::text_image_gather_scatter( + xpu_ctx->x_context(), + reinterpret_cast<XPUType*>(input.data<data_t>()), + reinterpret_cast<XPUType*>(text_input.data<data_t>()), + reinterpret_cast<XPUType*>(image_input.data<data_t>()), + reinterpret_cast<int*>(token_type_ids.data<int>()), + reinterpret_cast<int*>(text_index.data<int>()), + reinterpret_cast<int*>(image_index.data<int>()), + token_num, + text_token_num, + image_token_num, + hidden_size, + is_scatter + ); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter"); + break; + } + default: { + PD_THROW( + "NOT supported data type. Only support BFLOAT16. "); + break; + } + } +} + + +PD_BUILD_OP(text_image_gather_scatter) + .Inputs({"input", + "text_input", + "image_input", + "token_type_ids", + "text_index", + "image_index"}) + .Outputs({"text_input_out", + "image_input_out", + "text_index_out", + "image_index_out"}) + .Attrs({"is_scatter:bool"}) + .SetInplaceMap({{"text_input", "text_input_out"}, + {"image_input", "image_input_out"}, + {"text_index", "text_index_out"}, + {"image_index", "image_index_out"}}) + .SetKernelFn(PD_KERNEL(TextImageGatherScatter)); diff --git a/custom_ops/xpu_ops/src/ops/text_image_index_out.cc b/custom_ops/xpu_ops/src/ops/text_image_index_out.cc new file mode 100644 index 000000000..a0ce15036 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/text_image_index_out.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +void TextImageIndexOut( + const paddle::Tensor& token_type_ids, + const paddle::Tensor& text_index, + const paddle::Tensor& image_index) { + if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type() + != paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) { + PD_THROW("NOT supported data type. Only support INT32. 
"); + } + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + const int64_t token_num = token_type_ids.shape()[0]; + int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(), + token_type_ids.data(), + const_cast(text_index.data()), + const_cast(image_index.data()), + token_num); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out"); +} + + +PD_BUILD_OP(text_image_index_out) + .Inputs({"token_type_ids", + "text_index", + "image_index"}) + .Outputs({"text_index_out", + "image_index_out"}) + .SetInplaceMap({{"text_index", "text_index_out"}, + {"image_index", "image_index_out"}}) + .SetKernelFn(PD_KERNEL(TextImageIndexOut)); diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 0033e89de..5ce255956 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x, const TSCALE *scale_in, TY *y, TSCALE *scale_out, int64_t m, int64_t n); +DLL_EXPORT int text_image_index_out(Context* ctx, + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num); + +template +DLL_EXPORT int text_image_gather_scatter(api::Context* ctx, + T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter); /*--------------------------------------- MTP being --------------------------------------------*/ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu new file mode 100644 index 000000000..608cda1c6 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu @@ -0,0 +1,175 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/xtdk_io.h" + +namespace xpu3 { +namespace plugin { + +template +static __device__ inline void text_image_gather( + __global_ptr__ T* input, + __global_ptr__ T* text_input, + __global_ptr__ T* image_input, + __global_ptr__ int* token_type_ids, + __global_ptr__ int* text_index, + __global_ptr__ int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + T* input_lm) { + int cid = core_id(); + int clusterid = cluster_id(); + int token_start_cluster; + int token_end_cluster; + int token_start_core; + int token_end_core; + + const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 + // cluster partition + partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster); + if (token_start_cluster >= token_end_cluster) { + return; + } + int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster + // core partition + partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core); + int rows_core = token_end_core - token_start_core; // total rows for a core + token_start_core += token_start_cluster; + token_end_core += token_start_cluster; + + int read_len; + for (int i = token_start_core; i < token_end_core; i += 1) { + 
int token_type, text_image_token_idx; + __global_ptr__ T* text_image_input = nullptr; + __global_ptr__ int* text_image_index = nullptr; + + GM2LM(token_type_ids + i, &token_type, sizeof(int)); + if (token_type == 0) { + text_image_input = text_input; + text_image_index = text_index; + } else { + text_image_input = image_input; + text_image_index = image_index; + } + GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int)); + int input_offset = i * hidden_size; + int text_image_offset = text_image_token_idx * hidden_size; + + for (int j = 0; j < hidden_size; j += BUFSIZE) { + read_len = min(hidden_size - j, BUFSIZE); + GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len); + LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len); + } + } +} + +template +static __device__ inline void text_image_scatter( + __global_ptr__ T* input, + __global_ptr__ T* text_input, + __global_ptr__ T* image_input, + __global_ptr__ int* token_type_ids, + __global_ptr__ int* text_index, + __global_ptr__ int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + T* input_lm) { + int cid = core_id(); + int clusterid = cluster_id(); + int token_start_cluster; + int token_end_cluster; + int token_start_core; + int token_end_core; + + const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 + // cluster partition + partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster); + if (token_start_cluster >= token_end_cluster) { + return; + } + int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster + // core partition + partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core); + int rows_core = token_end_core - token_start_core; // total rows for a core + token_start_core += token_start_cluster; + token_end_core += token_start_cluster; + + int read_len; + for (int i = token_start_core; i < token_end_core; i += 1) { + int token_type, text_image_token_idx; + __global_ptr__ T* text_image_input = nullptr; + __global_ptr__ int* text_image_index = nullptr; + + GM2LM(token_type_ids + i, &token_type, sizeof(int)); + if (token_type == 0) { + text_image_input = text_input; + text_image_index = text_index; + } else { + text_image_input = image_input; + text_image_index = image_index; + } + GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int)); + int input_offset = i * hidden_size; + int text_image_offset = text_image_token_idx * hidden_size; + + for (int j = 0; j < hidden_size; j += BUFSIZE) { + read_len = min(hidden_size - j, BUFSIZE); + GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len); + LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len); + } + } +} + +template +__global__ void text_image_gather_scatter( + T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32 + __simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32 + if (is_scatter) { + text_image_scatter( + input, text_input, image_input, token_type_ids, text_index, image_index, + token_num, text_token_num, image_token_num, hidden_size, input_lm); + } else { + 
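+ // Gather path: the inverse copy, merging rows of text_input/image_input back into the fused input.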
text_image_gather( + input, text_input, image_input, token_type_ids, text_index, image_index, + token_num, text_token_num, image_token_num, hidden_size, input_lm); + } +} + + +#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \ + template __global__ void text_image_gather_scatter( \ + T* input, \ + T* text_input, \ + T* image_input, \ + int* token_type_ids, \ + int* text_index, \ + int* image_index, \ + int64_t token_num, \ + int64_t text_token_num, \ + int64_t image_token_num, \ + int64_t hidden_size, \ + bool is_scatter); + +_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16); + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu new file mode 100644 index 000000000..f8c972ef3 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_index_out.xpu @@ -0,0 +1,97 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/* + * copyright (C) 2025 KUNLUNXIN, Inc + */ + +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_primitive_template.h" + +namespace xpu3 { +namespace plugin { + +static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) { + for (int j = 0; j < size; j++) { + if (lm_x[j] == 0) { + lm_y1[j] = text_count; + text_count += 1; + } else { + lm_y2[j] = images_count; + images_count += 1; + } + } + mfence_lm_sm(); +} + +__global__ void text_image_index_out_kernel( + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num) { + const int cid = core_id(); + const int tid = core_id() * cluster_num() + cluster_id(); + const int nthreads = core_num() * cluster_num(); + if (tid >= 1) return; + constexpr int BUFSIZE = 1024; + constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int); + const int64_t len = token_num; + + __simd__ char buffer0[BUFSIZE * 3]; + __simd__ char buffer1[BUFSIZE * 3]; + __simd__ __shared__ char buffer2[64][BUFSIZE * 2]; + + DoublePtr> buffer_ptr_x((SmPtr((_shared_ptr_ int*)buffer2[cid]))); + TriplePtr> buffer_ptr_y1((LmPtr((int*)buffer0))); + TriplePtr> buffer_ptr_y2((LmPtr((int*)buffer1))); + int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64); + int64_t i = tid * buflen; + int read_size = 0; + int offset = nthreads * buflen; + + int text_count = 0; + int images_count = 0; + + if (i < len) { + read_size = min(buflen, len - i); + buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size); + buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size); + buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size); + mfence(); + } + while (i < len && i + offset < len) { + i = i + offset; + int read_size_next = min(buflen, len - i); + 
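+ // Double buffering: kick off async loads of the next tile, run the running-count pass + // on the current tile, store its results, then toggle the buffer pointers.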
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next); + buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next); + buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next); + + do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count); + + buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size); + buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size); + buffer_ptr_x.toggle(); + buffer_ptr_y1.toggle(); + buffer_ptr_y2.toggle(); + read_size = read_size_next; + } + if (i < len) { + do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count); + buffer_ptr_y1.gm_store_async(text_index + i, read_size); + buffer_ptr_y2.gm_store(image_index + i, read_size); + } +} +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp new file mode 100644 index 000000000..d4c52293c --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_gather_scatter.cpp @@ -0,0 +1,182 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
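+// +// Host-side wrapper for text_image_gather_scatter: checks the context and buffer sizes, +// requires token_num == text_token_num + image_token_num, and routes to a sequential CPU +// reference implementation or to the XPU3 kernel depending on the device type.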
+ +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +template +__attribute__((global)) void text_image_gather_scatter( + T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +template +static int cpu_wrapper( + Context* ctx, + T* input, // shape [token_num, hidden_size] + T* text_input, // shape [text_token_num, hidden_size] + T* image_input, // shape [image_token_num, hidden_size] + int* token_type_ids,// shape [token_num], 0 for text, 1 for image + int* text_index, // shape [token_num], mapping from input to text_input + int* image_index, // shape [token_num], mapping from input to image_input + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter) { + + if (is_scatter) { + // Scatter mode: input -> text_input/image_input + for (int64_t i = 0; i < token_num; i++) { + int token_type = token_type_ids[i]; + + T* text_image_input = nullptr; + int* text_image_index = nullptr; + if (token_type == 0) { + text_image_input = text_input; + text_image_index = text_index; + } else { // token_type == 1 + text_image_input = image_input; + text_image_index = image_index; + } + + int text_image_token_idx = text_image_index[i]; + int input_offset = i * hidden_size; + int text_image_offset = text_image_token_idx * hidden_size; + + for (int64_t j = 0; j < hidden_size; j++) { + T value = input[input_offset + j]; + text_image_input[text_image_offset + j] = value; + } + } + } else { + // Gather mode: text_input/image_input -> input + for (int64_t i = 0; i < token_num; i++) { + int token_type = token_type_ids[i]; + + T* text_image_input = nullptr; + int* text_image_index = nullptr; + if (token_type == 0) { + text_image_input = text_input; + text_image_index = text_index; + } else { // token_type == 1 + text_image_input = image_input; + text_image_index = image_index; + } + + int text_image_token_idx = text_image_index[i]; + int input_offset = i * hidden_size; + int text_image_offset = text_image_token_idx * hidden_size; + + for (int64_t j = 0; j < hidden_size; j++) { + T value = text_image_input[text_image_offset + j]; + input[input_offset + j] = value; + } + } + } + return api::SUCCESS; +} + +template +static int xpu3_wrapper( + Context* ctx, + T* input, + T* text_input, + T* image_input, + int* token_type_ids, + int* text_index, + int* image_index, + int64_t token_num, + int64_t text_token_num, + int64_t image_token_num, + int64_t hidden_size, + bool is_scatter) { + xpu3::plugin::text_image_gather_scatter <<ncluster(), 64, ctx->xpu_stream>>>( + input, text_input, image_input, token_type_ids, text_index, image_index, + token_num, text_token_num, image_token_num, hidden_size, is_scatter + ); + return api::SUCCESS; +} + + +template +int text_image_gather_scatter( + Context* ctx, + T* input, // shape [token_num, hidden_size] + T* text_input, // shape [text_token_num, hidden_size] + T* image_input, // shape [image_token_num, hidden_size] + int* token_type_ids,// shape [token_num], 0 for text, 1 for image + int* text_index, // shape [token_num], mapping from input to text_input + int* image_index, // shape [token_num], mapping from input to image_input + int64_t token_num, + int64_t text_token_num, + 
int64_t image_token_num, + int64_t hidden_size, + bool is_scatter) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T); + WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index); + WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter); + WRAPPER_DUMP(ctx); + WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input); + if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size] + WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input); + } + if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size] + WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input); + } + WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids); + WRAPPER_CHECK_PTR(ctx, int, token_num, text_index); + WRAPPER_CHECK_PTR(ctx, int, token_num, image_index); + WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper( + ctx, input, text_input, image_input, token_type_ids, text_index, image_index, + token_num, text_token_num, image_token_num, hidden_size, is_scatter + ); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper( + ctx, input, text_input, image_input, token_type_ids, text_index, image_index, + token_num, text_token_num, image_token_num, hidden_size, is_scatter + ); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + + +template int text_image_gather_scatter( + Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp new file mode 100644 index 000000000..3a2cd44c4 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/text_image_index_out.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
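+// +// Host-side wrapper for text_image_index_out: for every token, assigns its destination row +// inside the text or image buffer by keeping running counts of text and image tokens. The +// prefix count is inherently sequential, so the XPU3 kernel is launched on a single logical +// thread (<<<1, 1, stream>>>) and the CPU path mirrors the same loop.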
+ +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context* ctx, + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num) { + int text_count = 0; + int image_count = 0; + + for (int64_t i = 0; i < token_num; ++i) { + if (token_type_ids[i] == 0) { + text_index[i] = text_count; + ++text_count; + } else { + image_index[i] = image_count; + ++image_count; + } + } + return api::SUCCESS; + +} + +static int xpu3_wrapper(Context* ctx, + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num) { + + xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>( + token_type_ids, + text_index, + image_index, + token_num); + return api::SUCCESS; +} + +int text_image_index_out(Context* ctx, + const int* token_type_ids, // x + int* text_index, // y1 + int* image_index, // y2 + const int64_t token_num) { + + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int); + WRAPPER_DUMP_PARAM4( + ctx, token_type_ids, text_index, image_index, token_num); + WRAPPER_DUMP(ctx); + WRAPPER_ASSERT_GT(ctx, token_num, 0); + WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids); + WRAPPER_CHECK_PTR(ctx, int, token_num, text_index); + WRAPPER_CHECK_PTR(ctx, int, token_num, image_index); + + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + token_type_ids, + text_index, + image_index, + token_num); + } else if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + token_type_ids, + text_index, + image_index, + token_num); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 9d8925cea..075a77f24 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -30,6 +30,7 @@ import paddle from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.platforms import current_platform from fastdeploy.utils import llm_logger @@ -157,6 +158,7 @@ class ResourceManagerV1(ResourceManager): # TODO: set condition to new _get_num_new_tokens num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) + request.with_image = False if not self.config.model_config.enable_mm: return num_new_tokens @@ -219,7 +221,10 @@ class ResourceManagerV1(ResourceManager): grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) grid_thw = paddle.to_tensor(grid_thw, dtype="int64") - from fastdeploy.model_executor.ops.gpu import get_img_boundaries + if current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import get_img_boundaries + else: + from fastdeploy.model_executor.ops.gpu import get_img_boundaries request.multimodal_img_boundaries = get_img_boundaries( task_input_ids=input_ids, grid_thw=grid_thw, image_patch_id=image_patch_id diff 
--git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 2d812b4e9..10608676c 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -232,6 +232,8 @@ class XPUForwardMeta(ForwardMeta): dec_batch: Optional[paddle.Tensor] = None # total_enc_len: Optional[paddle.Tensor] = None + # position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM' + pos_emb_type: Optional[str] = "NORMAL" @dataclass diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 938693738..3cd64b6fa 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -183,5 +183,7 @@ class XPUAttentionBackend(AttentionBackend): forward_meta.encoder_batch_map_cpu, forward_meta.decoder_context_len_cpu, forward_meta.decoder_batch_map_cpu, + forward_meta.pos_emb_type, + self.rope_3d, ) return res diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 3c571697f..adb39d05d 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -72,7 +72,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): layer.top_k, False, # moe group, used in deepseek ) - if layer.tp_size > 1: + if layer.reduce_results and layer.tp_size > 1: from fastdeploy.distributed.communication import ( tensor_model_parallel_all_reduce, ) @@ -252,7 +252,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase): layer.top_k, False, # moe group, used in deepseek ) - if layer.tp_size > 1: + if layer.reduce_results and layer.tp_size > 1: from fastdeploy.distributed.communication import ( tensor_model_parallel_all_reduce, ) diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index de1c405af..7be6d2b5c 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -31,6 +31,7 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.decorator import ( cuda_graph_buffers, support_graph_optimization, @@ -44,20 +45,15 @@ from fastdeploy.model_executor.models.ernie4_5_moe import ( Ernie4_5_Attention, Ernie4_5_MLP, ) +from fastdeploy.model_executor.models.ernie4_5_vl.image_op import ( + text_image_gather_scatter, + text_image_index_out, +) from fastdeploy.model_executor.models.model_base import ( ModelCategory, ModelForCasualLM, ModelRegistry, ) -from fastdeploy.platforms import current_platform - -if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import ( - text_image_gather_scatter, - text_image_index_out, - ) - -from fastdeploy.model_executor.forward_meta import ForwardMeta class Ernie4_5_VLMLP(Ernie4_5_MLP): diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py b/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py new file mode 100644 index 000000000..4324d921f --- /dev/null +++ b/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py @@ -0,0 +1,32 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + text_image_gather_scatter, + text_image_index_out, + ) +elif current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + text_image_gather_scatter, + text_image_index_out, + ) +else: + raise ImportError("Unsupported platform, only support CUDA and XPU") + +__all__ = ["text_image_gather_scatter", "text_image_index_out"] diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 452a56aa2..b5fa856ce 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -25,6 +25,7 @@ from paddle import nn from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request, RequestType +from fastdeploy.input.ernie4_5_vl_processor import DataProcessor from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, @@ -34,10 +35,11 @@ from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) -from fastdeploy.model_executor.layers.rotary_embedding import get_rope +from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler from fastdeploy.model_executor.model_loader import get_model_loader +from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.ops.xpu import ( adjust_batch, get_infer_param, @@ -201,6 +203,45 @@ def xpu_post_process( update_inputs, ) + # handle vl: + if model_output.enable_thinking: + exists_think_end = sampled_token_ids == model_output.think_end_id + paddle.assign( + paddle.where( + exists_think_end, + model_output.need_think_end - 1, + model_output.need_think_end, + ), + model_output.need_think_end, + ) + + paddle.assign( + paddle.where( + model_output.need_think_end.cast("bool"), + model_output.reasoning_index - 1, + model_output.reasoning_index, + ), + model_output.reasoning_index, + ) + + stop_wo_think = ( + (sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True) + | (model_output.reasoning_index == 0) + ) & (model_output.need_think_end > 0) + sampled_token_ids = paddle.where( + stop_wo_think, + model_output.think_end_id, + sampled_token_ids, + ) + paddle.assign( + paddle.where( + stop_wo_think, + model_output.need_think_end - 1, + model_output.need_think_end, + ), + model_output.need_think_end, + ) + # 1. 
Set stop value paddle.assign( paddle.where( @@ -340,11 +381,36 @@ class XPUModelRunner(ModelRunnerBase): def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int): super().__init__(fd_config=fd_config, device=device) + self.enable_mm = self.model_config.enable_mm self.rank = rank self.local_rank = local_rank + self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop + + # VL model config: + if self.enable_mm: + self._init_image_preprocess() + + self.amp_black = [ + "reduce_sum", + "c_softmax_with_cross_entropy", + "elementwise_div", + "sin", + "cos", + "sort", + "multinomial", + ] + self.amp_white = [ + "lookup_table", + "lookup_table_v2", + "flash_attn", + "matmul", + "matmul_v2", + "fused_gemm_epilogue", + ] # Sampler - self.sampler = Sampler() + # TODO(lilujia): sync with GPU + self.sampler = Sampler(fd_config) # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] @@ -364,18 +430,28 @@ class XPUModelRunner(ModelRunnerBase): ).cpu() # Initialize attention Backend - # Note(gonshaotian): Currently, all attention layers share one attention backend instance. + # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance. # In the future, we will expand it as a list. self.attn_backends: list[AttentionBackend] = [] - self.initialize_attn_backend() # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None + def exist_prefill(self): + """ + check whether prefill stage exists + """ + if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: + return 1 + else: + return 0 + def insert_tasks_v1(self, req_dicts: List[Request]): """ Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 + req_dict: A list of Request dict + num_running_requests: batch_size """ # NOTE(luotingdan): Lazy initialize kv cache if "caches" not in self.share_inputs: @@ -388,11 +464,54 @@ class XPUModelRunner(ModelRunnerBase): request = req_dicts[i] idx = request.idx if request.task_type.value == RequestType.PREFILL.value: # prefill task - logger.debug(f"Handle prefill request {request} at idx {idx}") prefill_start_index = request.prefill_start_index prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index - input_ids = request.prompt_token_ids + request.output_token_ids + if self.enable_mm: + inputs = request.multimodal_inputs + if request.with_image: + vision_inputs = {} + vision_inputs["input_ids"] = paddle.to_tensor( + inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 + ) + vision_inputs["token_type_ids"] = paddle.to_tensor( + inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 + ) + vision_inputs["image_type_ids"] = paddle.to_tensor( + inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end], + dtype=paddle.int64, + ) + vision_inputs["images"] = paddle.to_tensor( + inputs["images"][request.image_start : request.image_end], dtype="uint8" + ) + vision_inputs["grid_thw"] = paddle.to_tensor( + inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" + ) + self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs) + else: + self.share_inputs["image_features"] = None + + if inputs["position_ids"] is not None: + position_ids = paddle.to_tensor( + request.multimodal_inputs["position_ids"], + dtype="int64", + ).unsqueeze([0]) + else: + position_ids = None + + enable_thinking = 
request.get("enable_thinking", True) + enable_thinking = enable_thinking if enable_thinking is not None else True + self.share_inputs["enable_thinking"][:] = enable_thinking + self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 + self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) + self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( + position_ids, request.get("max_tokens", 2048) + ) + + if len(request.output_token_ids) == 0: + input_ids = request.prompt_token_ids + else: + input_ids = request.prompt_token_ids + request.output_token_ids logger.debug( f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}" ) @@ -475,41 +594,86 @@ class XPUModelRunner(ModelRunnerBase): if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): - request.stop_seqs_len.append(0) - self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32") - self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( - request.get("stop_token_ids"), dtype="int64" + request.sampling_params.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array( + request.sampling_params.stop_seqs_len, dtype="int32" ) + self.share_inputs["stop_seqs"][ + idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0]) + ] = np.array(request.get("stop_token_ids"), dtype="int64") + else: + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True - def process_prefill_inputs(self, req_dicts: List[Request]): + def insert_prefill_inputs(self, req_dicts: List[Request]): """Process inputs for prefill tasks and update share_inputs buffer""" req_len = len(req_dicts) for i in range(req_len): request = req_dicts[i] idx = request.idx - length = request.prompt_token_ids_len + length = len(request.prompt_token_ids) + assert length > 0, "The prompt requested must not be empty." 
+ self.share_inputs["pre_ids"][idx : idx + 1] = -1 + self.share_inputs["step_idx"][idx : idx + 1] = 0 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) + self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) + if self.enable_mm: + inputs = self._preprocess_mm_task(request.multimodal_inputs) + if inputs.get("images") is not None: + self.share_inputs["image_features"] = self.extract_vision_features(inputs) + else: + # Compatible with the situation that lacks images and videos + self.share_inputs["image_features"] = None + position_ids = inputs["position_ids"] + length = inputs["input_ids"].shape[1] + self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"] + else: + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["prompt_lens"][idx : idx + 1] = length + + if self.enable_mm: + enable_thinking = request.get("enable_thinking", True) + enable_thinking = enable_thinking if enable_thinking is not None else True + self.share_inputs["enable_thinking"][:] = enable_thinking + self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 + self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) + self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( + position_ids, request.get("max_tokens", 2048) + ) + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + + def get_attr_from_request(request, attr, default_value=None): + res = request.get(attr, default_value) + if res is not None: + return res + else: + return default_value + assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) - self.share_inputs["pre_ids"][idx : idx + 1] = -1 - self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) + self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7) self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0) - self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) - self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) - self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) - self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) - self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length - self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length - self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - self.share_inputs["step_idx"][idx : idx + 1] = 0 - self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = 
get_attr_from_request( + request, "repetition_penalty", 1.0 + ) + self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request( + request, "frequency_penalty", 0.0 + ) + self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request( + request, "presence_penalty", 0.0 + ) + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( "max_tokens", self.model_config.max_model_len ) @@ -540,11 +704,15 @@ class XPUModelRunner(ModelRunnerBase): if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): - request.stop_seqs_len.append(0) - self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32") - self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( - request.get("stop_token_ids"), dtype="int64" + request.sampling_params.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array( + request.sampling_params.stop_seqs_len, dtype="int32" ) + self.share_inputs["stop_seqs"][ + idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0]) + ] = np.array(request.get("stop_token_ids"), dtype="int64") + else: + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 self.share_inputs["not_need_stop"][0] = True @@ -565,6 +733,11 @@ class XPUModelRunner(ModelRunnerBase): self.model_config.pad_token_id, dtype="int64", ) + self.share_inputs["prompt_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + self.model_config.pad_token_id, + dtype="int64", + ) self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64") self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") @@ -627,13 +800,15 @@ class XPUModelRunner(ModelRunnerBase): # Initialize rotary position embedding tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) + # TODO(gongshaotian): move to models - self.share_inputs["rope_emb"] = get_rope( - rotary_dim=self.model_config.head_dim, - position_ids=tmp_position_ids, - base=self.model_config.rope_theta, - model_config=self.model_config, - ) + if not self.enable_mm: + self.share_inputs["rope_emb"] = get_rope( + rotary_dim=self.model_config.head_dim, + position_ids=tmp_position_ids, + base=self.model_config.rope_theta, + model_config=self.model_config, + ) # Set block tables pre_max_block_num = ( @@ -654,18 +829,40 @@ class XPUModelRunner(ModelRunnerBase): self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32") # Initialize stop seqs - self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32") + self.share_inputs["stop_seqs_len"] = paddle.full( + [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32" + ) self.share_inputs["stop_seqs"] = paddle.full( [ + max_num_seqs, self.model_config.max_stop_seqs_num, self.model_config.stop_seqs_max_len, ], -1, - dtype="int32", + dtype="int64", ) + if self.enable_mm: + head_dim = self.model_config.head_dim + self.share_inputs["rope_emb"] = paddle.full( + shape=[ + max_num_seqs, + 2, + 1, + self.parallel_config.max_model_len, + 1, + head_dim // 2, + ], + fill_value=0, + dtype="float32", + 
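+ # Per-request 3D RoPE embedding table; filled by prepare_rope3d() as requests arrive.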
) + self.share_inputs["image_features"] = None + self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool") + self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + def _prepare_inputs(self, is_dummy_run=False) -> None: - """prepare the model inputs""" + """Prepare the model inputs""" if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run: recover_decode_task( self.share_inputs["stop_flags"], @@ -689,10 +886,13 @@ class XPUModelRunner(ModelRunnerBase): # Update bad tokens len max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"]) + if self.enable_mm: # pos_emb_type is different in EB and VL + self.forward_meta.pos_emb_type = "HALF_HEAD_DIM" self.forward_meta.attn_backend = self.attn_backends[0] self.initialize_attention_backend() # Get sampling metadata + # TODO(lilujia): sync with GPU self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], @@ -703,12 +903,16 @@ class XPUModelRunner(ModelRunnerBase): seed=self.share_inputs["infer_seed"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], + prompt_ids=self.share_inputs["prompt_ids"], + prompt_lens=self.share_inputs["prompt_lens"], frequency_penalties=self.share_inputs["frequency_score"], presence_penalties=self.share_inputs["presence_score"], repetition_penalties=self.share_inputs["penalty_score"], min_dec_lens=self.share_inputs["min_dec_len"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len], eos_token_ids=self.share_inputs["eos_token_id"], + enable_early_stop=self.enable_early_stop, + stop_flags=self.share_inputs["stop_flags"], ) def load_model(self) -> None: @@ -723,7 +927,7 @@ class XPUModelRunner(ModelRunnerBase): # 3. 
Load drafter model (for speculative decoding) def get_model(self) -> nn.Layer: - """get current model""" + """Get current model""" return self.model def initialize_attention_backend(self): @@ -741,6 +945,7 @@ class XPUModelRunner(ModelRunnerBase): cache_kvs = {} max_block_num = self.num_gpu_blocks + # Get kv cache dtype cache_type = self.parallel_config.dtype kv_cache_quant_type = None @@ -800,33 +1005,6 @@ class XPUModelRunner(ModelRunnerBase): ) self.attn_backends.append(attn_backend) - def capture_model(self) -> None: - """ - Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes' - """ - logger.warn("XPU not support cuda graph currently") - pass - - @sot_warmup_guard(True) - def sot_warmup(self) -> None: - start_time = time.perf_counter() - for batch_size in self.sot_warmup_sizes: - self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=batch_size, - ) - logger.info(f"SOT warmup the model with the batch size:{batch_size}") - logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") - - def exist_prefill(self): - """ - check whether prefill stage exist - """ - if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: - return 1 - else: - return 0 - def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int): """Set dummy prefill inputs to share_inputs""" full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10) @@ -838,7 +1016,7 @@ class XPUModelRunner(ModelRunnerBase): for i in range(batch_size): idx = i self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) - + self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1) self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length @@ -897,6 +1075,24 @@ class XPUModelRunner(ModelRunnerBase): else: paddle.device.xpu.set_debug_level(debug_level) + def capture_model(self) -> None: + """ + Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes' + """ + logger.warn("XPU not support cuda graph currently") + pass + + @sot_warmup_guard(True) + def sot_warmup(self) -> None: + start_time = time.perf_counter() + for batch_size in self.sot_warmup_sizes: + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=batch_size, + ) + logger.info(f"SOT warmup the model with the batch size:{batch_size}") + logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") + def execute_model( self, model_forward_batch: Optional[List[Request]] = None, @@ -921,13 +1117,20 @@ class XPUModelRunner(ModelRunnerBase): # 2. Padding inputs for cuda graph # 3. Execute model - model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta) + if self.enable_mm: + model_output = self.model( + self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta + ) + else: + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta, + ) - hiddden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta) + hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta) # 4. 
@@ -921,13 +1117,20 @@ class XPUModelRunner(ModelRunnerBase):
         # 2. Padding inputs for cuda grph

         # 3. Execute model
-        model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta)
+        if self.enable_mm:
+            model_output = self.model(
+                self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
+            )
+        else:
+            model_output = self.model(
+                ids_remove_padding=self.share_inputs["ids_remove_padding"],
+                forward_meta=self.forward_meta,
+            )

-        hiddden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
+        hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)

         # 4. Compute logits, Sample
-        logits = self.model.compute_logits(hiddden_states)
-
+        logits = self.model.compute_logits(hidden_states)
         sampler_output = self.sampler(logits, self.sampling_metadata)

         # 5. Speculative decode
@@ -947,15 +1150,21 @@ class XPUModelRunner(ModelRunnerBase):
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
             is_block_step=self.share_inputs["is_block_step"],
+            # Speculative decoding
+            full_hidden_states=None,
             msg_queue_id=self.parallel_config.msg_queue_id,
             mp_rank=self.local_rank,
             use_ep=self.parallel_config.use_ep,
-            # 投机解码
-            full_hidden_states=None,
             draft_tokens=None,
             actual_draft_token_num=None,
             accept_tokens=None,
             accept_num=None,
+            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
+            think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
+            need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
+            reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
+            stop_token_ids=self.share_inputs["stop_seqs"],
+            stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
         xpu_post_process(
             sampled_token_ids=sampler_output.sampled_token_ids,
@@ -984,13 +1193,43 @@ class XPUModelRunner(ModelRunnerBase):

     @profile_run_guard(True)
     def profile_run(self) -> None:
-        """Execute a forward pass with dummy inputs to profile the memory usage of the model."""
+        """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
+
+        self.num_gpu_blocks = self.parallel_config.total_block_num
+        self.initialize_kv_cache()

         self._dummy_run(
             num_tokens=int(self.scheduler_config.max_num_batched_tokens),
             batch_size=min(self.scheduler_config.max_num_seqs, 1),
         )

+    def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
+        """
+        Set a globally unified block number and update the model's shared input.
+
+        Args:
+            num_gpu_blocks: the globally unified number of KV cache blocks
+        """
+        self.num_gpu_blocks = num_gpu_blocks
+
+        # Reset block table and kv cache with global block num
+        self.initialize_kv_cache()
+
+        # Reset free list
+        free_list = list(
+            range(
+                self.num_gpu_blocks - 1,
+                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
+                -1,
+            )
+        )
+        self.free_list_len = len(free_list)
+        self.share_inputs.update(
+            {
+                "free_list": paddle.to_tensor(free_list, dtype="int32"),
+                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
+            }
+        )
+
     def clear_block_table(self) -> None:
         """
         Clear the block tables and kv cache after profiling.
@@ -1025,41 +1264,135 @@ class XPUModelRunner(ModelRunnerBase):
         byte_of_dtype = 2

         hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
-        required_memory = (
-            byte_of_dtype
-            * 2  # k + v
-            * (self.cache_config.block_size * hidden_dim)
-            * self.model_config.num_hidden_layers
-        )
+        num_layers = self.model_config.num_hidden_layers
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
        return required_memory
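To make the `cal_theortical_kvcache` formula above concrete, here is a worked example; the dimensions are illustrative placeholders, not taken from any particular model config:

```python
# bf16 KV cache: 2 bytes per element, K and V both stored per layer.
byte_of_dtype = 2
block_size = 64                      # tokens per KV block (illustrative)
head_dim, kv_num_heads = 128, 8      # -> hidden_dim = 1024
num_hidden_layers = 48
hidden_dim = head_dim * kv_num_heads
required_memory = byte_of_dtype * 2 * (block_size * hidden_dim) * num_hidden_layers  # k + v
print(required_memory)               # 12582912 bytes = 12 MiB per block
```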

-    def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
-        """
-        Set a globally unified block number and update the model's shared input.
-        Args:
-            num_gpu_blocks:
-        """
-        self.num_gpu_blocks = num_gpu_blocks
-
-        # Reset block table and kv cache with global block num
-        self.initialize_kv_cache()
-
-        # Reset free list
-        free_list = list(
-            range(
-                self.num_gpu_blocks - 1,
-                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
-                -1,
-            )
-        )
-        self.free_list_len = len(free_list)
-        self.share_inputs.update(
-            {
-                "free_list": paddle.to_tensor(free_list, dtype="int32"),
-                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
-            }
-        )
-
     def not_need_stop(self) -> bool:
-        """ """
+        """Stop decoding if the tensor meets the termination condition"""
         return self.share_inputs["not_need_stop"][0]
+
+    def clear_cache(self):
+        """Clear cached data from shared inputs and forward metadata"""
+        self.share_inputs.pop("caches", None)
+        if self.forward_meta is not None:
+            self.forward_meta.clear_caches()
+
+    def _init_image_preprocess(self) -> None:
+        processor = DataProcessor(
+            tokenizer_name=self.model_config.model,
+            image_preprocessor_name=str(self.model_config.model),
+        )
+        processor.eval()
+        image_preprocess = processor.image_preprocessor
+        image_preprocess.image_mean_tensor = paddle.to_tensor(image_preprocess.image_mean, dtype="float32").reshape(
+            [1, 3, 1, 1]
+        )
+        image_preprocess.image_std_tensor = paddle.to_tensor(image_preprocess.image_std, dtype="float32").reshape(
+            [1, 3, 1, 1]
+        )
+        image_preprocess.rescale_factor = paddle.to_tensor(image_preprocess.rescale_factor, dtype="float32")
+        image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze([-2, -1]).repeat_interleave(
+            self.model_config.vision_config.patch_size**2 * 1, -1
+        )
+        image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze([-2, -1]).repeat_interleave(
+            self.model_config.vision_config.patch_size**2 * 1, -1
+        )
+        self.image_preprocess = image_preprocess
+
+    def _preprocess_mm_task(self, one: dict) -> dict:
+        """Preprocess a single multimodal task into model-ready tensors"""
+
+        input_ids = one["input_ids"][np.newaxis, :]
+        input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
+        token_type_ids = one["token_type_ids"][np.newaxis, :]
+        token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
+
+        if one["images"] is not None:
+            image_type_ids = one["image_type_ids"][np.newaxis, :]
+            images = one["images"]
+            image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
+            images = paddle.to_tensor(images, dtype="uint8")
+            grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
+        else:
+            image_type_ids = None
+            images = None
+            grid_thw = None
+
+        if one["position_ids"] is not None:
+            position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
+        else:
+            position_ids = None
+
+        result = dict(
+            input_ids=input_ids,
+            image_type_ids=image_type_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            grid_thw=grid_thw,
+            images=images,
+        )
+        return result
+
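`_init_image_preprocess` above pre-bakes the rescale/mean/std constants so that raw uint8 pixels can be standardized on the fly; the actual application happens in `extract_vision_features` below. A minimal NumPy sketch of that arithmetic, using CLIP-style constants purely as placeholders for whatever the model's preprocessor config supplies:

```python
import numpy as np

rescale_factor = 1.0 / 255.0                             # uint8 -> [0, 1]
image_mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
image_std = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")

pixels = np.random.randint(0, 256, size=(3, 224, 224)).astype("float32")
# Same order of operations as extract_vision_features: rescale, subtract mean, divide std.
normalized = (rescale_factor * pixels - image_mean[:, None, None]) / image_std[:, None, None]
print(normalized.mean(), normalized.std())               # roughly zero mean, unit variance
```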
+    @paddle.no_grad()
+    def extract_vision_features(self, inputs: dict) -> paddle.Tensor:
+        """Extract vision features from the input images"""
+        assert inputs["images"] is not None
+        grid_thw = inputs["grid_thw"]
+
+        images = inputs["images"].cast("float32")
+        images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
+        images = images / self.image_preprocess.image_std_tensor
+        images = images.cast("bfloat16")
+
+        token_type_ids = inputs["token_type_ids"]
+        token_type_ids_w_video = token_type_ids
+        input_ids = inputs["input_ids"]
+        # convert to img patch id
+        # TODO(lulinjun): may need to check model_config and model_cfg
+        image_mask = input_ids == self.model_config.im_patch_id
+        image_type_ids = inputs["image_type_ids"]
+        with paddle.amp.auto_cast(
+            True,
+            custom_black_list=self.amp_black,
+            custom_white_list=self.amp_white,
+            level="O2",
+            dtype=self.parallel_config.dtype,
+        ):
+            image_features = self.model.vision_model.extract_feature(images, grid_thw)
+            if self.parallel_config.tensor_parallel_size > 1:
+                S, C = image_features.shape
+                image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
+                image_features = ScatterOp.apply(image_features, axis=-1)  # split the feature dim across model-parallel ranks
+                image_features = image_features.reshape([S, -1])
+            image_features = self.model.resampler_model(
+                image_features,
+                image_mask,
+                token_type_ids_w_video,
+                image_type_ids,
+                grid_thw,
+            )
+        return image_features
+
+    @paddle.no_grad()
+    def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
+        """Prepare 3D rotary position embeddings"""
+
+        prefix_max_position_ids = paddle.max(position_ids) + 1
+        dec_pos_ids = paddle.tile(
+            paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
+            [1, 1, 3],
+        )
+        dec_pos_ids = dec_pos_ids + prefix_max_position_ids
+        position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
+
+        rope_emb = get_rope_3d(
+            position_ids=position_ids_3d_real,
+            rotary_dim=self.model_config.head_dim,
+            partial_rotary_factor=1.0,
+            base=self.model_config.rope_theta,
+            max_position=self.parallel_config.max_model_len,
+            freq_allocation=getattr(self.model_config, "freq_allocation", 20),
+            model_type=self.model_config.model_type,
+        )
+        return rope_emb
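`prepare_rope3d` extends the prefill 3D position ids (one axis each for time, height, width) into the decode phase, where all three axes advance together starting just past the largest prefill position. A NumPy sketch of the index construction only, with made-up positions; the rotary embedding itself is left to `get_rope_3d`:

```python
import numpy as np

# Prefill positions: [batch, prompt_len, 3]; say the largest index used was 41.
position_ids = np.array([[[0, 0, 0], [1, 0, 1], [41, 3, 7]]], dtype="int64")
prefix_max = position_ids.max() + 1                       # 42
max_len = 4                                               # decode steps to pre-compute
dec = np.tile(np.arange(max_len, dtype="int64")[None, :, None], (1, 1, 3)) + prefix_max
position_ids_3d = np.concatenate([position_ids, dec], axis=1)
print(position_ids_3d[0, -max_len:])                      # [[42 42 42] [43 43 43] ...]
```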
diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py
index 66d0d9cb9..ef7450ec7 100644
--- a/fastdeploy/worker/xpu_worker.py
+++ b/fastdeploy/worker/xpu_worker.py
@@ -51,12 +51,13 @@ class XpuWorker(WorkerBase):
         """Initialize device and Construct model runner"""
         if paddle.is_compiled_with_xpu():
             # Set environment variable
+            self.device_ids = self.parallel_config.device_ids.split(",")
             self.device = f"xpu:{self.local_rank}"
             paddle.device.set_device(self.device)
             paddle.set_default_dtype(self.parallel_config.dtype)
-            self.device_ids = self.parallel_config.device_ids.split(",")

             gc.collect()
+            paddle.device.xpu.empty_cache()
         else:
             raise RuntimeError(f"Not support device type: {self.device_config.device}")
@@ -69,12 +70,11 @@ class XpuWorker(WorkerBase):
             local_rank=self.local_rank,
         )

-    def graph_optimize_and_warm_up_model(self) -> None:
+    def exist_prefill(self):
         """
-        Perform the warm-up and the graph optimization
+        Check whether a prefill stage exists in the current batch
         """
-        if self.model_runner.graph_opt_level >= 1:
-            self.model_runner.sot_warmup()
+        return self.model_runner.exist_prefill()

     def determine_available_memory(self) -> int:
         """
@@ -133,20 +133,17 @@ class XpuWorker(WorkerBase):
         paddle.device.xpu.empty_cache()
         return available_kv_cache_memory  # approximate value

-    def cal_theortical_kvcache(self) -> int:
-        """ """
-        return self.model_runner.cal_theortical_kvcache()
-
     def load_model(self) -> None:
-        """ """
+        """Load the model"""
         self.model_runner.load_model()

     def get_model(self) -> nn.Layer:
-        """ """
+        """Get the current model"""
         return self.model_runner.get_model()

     def initialize_cache(self, num_gpu_blocks: int) -> None:
-        """ """
+        """Initialize the KV cache with an accurate num_gpu_blocks"""
         self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
         model_forward_batch: Optional[List[Request]] = None,
@@ -158,12 +155,6 @@ class XpuWorker(WorkerBase):
         """ """
         return self.model_runner.execute_model(model_forward_batch, num_running_requests, is_dummy_run)

-    def exist_prefill(self):
-        """
-        check whether prefill stage exist
-        """
-        return self.model_runner.exist_prefill()
-
     def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int = -1) -> None:
         """Process new requests and then start the decode loop
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
         """
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
         else:
-            self.model_runner.process_prefill_inputs(req_dicts=req_dicts)
+            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
+
+    def graph_optimize_and_warm_up_model(self) -> None:
+        """
+        Perform the warm-up and the graph optimization
+        """
+        if self.model_runner.graph_opt_level >= 1:
+            self.model_runner.sot_warmup()

     def check_health(self) -> bool:
         """ """
         return True
+
+    def cal_theortical_kvcache(self) -> int:
+        """Calculate the theoretical KV cache block memory required"""
+        return self.model_runner.cal_theortical_kvcache()
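Finally, a worked example of the free-list bookkeeping that `initialize_cache` triggers through `update_share_input_block_num` (the block count and ratio are illustrative): blocks above the `kv_cache_ratio` watermark are kept free for decode-time growth, handed out from the highest index downward.

```python
# With 1000 blocks total and kv_cache_ratio = 0.75, blocks 750..999 stay free.
num_gpu_blocks, kv_cache_ratio = 1000, 0.75
free_list = list(range(num_gpu_blocks - 1, int(num_gpu_blocks * kv_cache_ratio) - 1, -1))
print(free_list[0], free_list[-1], len(free_list))  # 999 750 250
```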