[XPU] support XPU VL model inference (#4030)

* [XPU] support XPU VL model inference

* fix image op import and device check

* rebase develop

* fix perf
Author: Lucas
Date: 2025-09-25 14:34:15 +08:00
Committed by: GitHub
Parent: e36eccfdad
Commit: 87179cb744
18 changed files with 1300 additions and 146 deletions


@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
     const paddle::Tensor &encoder_seq_lod_cpu,
     const paddle::Tensor &encoder_batch_map_cpu,
     const paddle::Tensor &decoder_context_len_cpu,
-    const paddle::Tensor &decoder_batch_map_cpu) {
+    const paddle::Tensor &decoder_batch_map_cpu,
+    const std::string &pos_emb_type="NORMAL",
+    bool rope_3d=false) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx =
       paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
   int enc_batch = enc_batch_tensor.data<int32_t>()[0];
   int dec_batch = dec_batch_tensor.data<int32_t>()[0];
   int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
+  int rope_max_seqlen = 0;
+  int rope_3d_num_seqs = 1;
+  if (rope_3d) {
+    rope_max_seqlen = rotary_embs.dims()[3];
+    rope_3d_num_seqs = rotary_embs.dims()[0];
+  } else {
+    rope_max_seqlen = rotary_embs.dims()[2];
+  }
   auto block_attn_out =
       paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         prefix_lens_vp,   // start_tokens
         param.batch_size, // batch_size
         1,                // emb_batch_size
-        rotary_embs.dims()[2], // max_seqlen
+        rope_max_seqlen,  // max_seqlen
         param.head_num, param.kv_head_num, param.head_dim,
         param.max_batch_size, block_size, max_block_per_seq, "BLHD",
-        "HLD", "NORMAL",
+        "HLD", pos_emb_type,
         !p_kcache_perhead_scale.defined()
             ? nullptr
             : p_kcache_perhead_scale.data<float>() +
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         vsl.slot_mapping_vp, // real_batch
         param.batch_size,    // batch_size
         1,                   // emb_batch_size
-        rotary_embs.dims()[2], // max_seqlen TODO!!double check
+        rope_max_seqlen,     // max_seqlen
         param.head_num, param.kv_head_num, param.head_dim,
         param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
-        "NORMAL",
+        pos_emb_type,
         !p_kcache_perhead_scale.defined()
             ? nullptr
             : p_kcache_perhead_scale.data<float>() +
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
         param.kv_head_num, // v_cache_scale_inv
         nullptr,           // k_cache_zp
         nullptr,           // v_cache_zp
-        false);            // b_c8_pc
+        false,             // b_c8_pc
+        rope_3d,           // rope_3d
+        rope_3d_num_seqs);
     XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);

     // attn decode
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
         "decoder_context_len_cpu",
         "decoder_batch_map_cpu",
     })
+    .Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
    .Outputs({"block_attn_out"})
    .SetKernelFn(PD_KERNEL(BlockAttnKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
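The new rope_3d attribute only changes which axes of rotary_embs are read for the sequence length and the per-request batch. A minimal Python sketch of that selection, assuming the 3-D rope buffer layout [max_num_seqs, 2, 1, max_model_len, 1, head_dim // 2] that the VL model runner allocates later in this commit and a conventional 2-D layout with the sequence length on axis 2; the helper itself is illustrative and not part of the commit:

def rope_seqlen_info(rotary_embs_shape, rope_3d=False):
    """Mirror of the rope_max_seqlen / rope_3d_num_seqs selection in BlockAttnKernel."""
    if rope_3d:
        # assumed 3-D layout: [num_seqs, 2, 1, max_seqlen, 1, head_dim // 2]
        return rotary_embs_shape[3], rotary_embs_shape[0]
    # assumed 2-D layout: sequence length on axis 2, one shared rope table
    return rotary_embs_shape[2], 1


print(rope_seqlen_info((8, 2, 1, 8192, 1, 64), rope_3d=True))   # (8192, 8)
print(rope_seqlen_info((2, 1, 8192, 1, 128), rope_3d=False))    # (8192, 1)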


@@ -0,0 +1,60 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
const paddle::Tensor& grid_thw,
const int64_t image_patch_id) {
// All tensor in cpu
auto input_ids_ptr = task_input_ids.data<int64_t>();
int64_t seq_lens_origin = task_input_ids.numel();
auto grid_thw_ptr = grid_thw.data<int64_t>();
int token_times = 4;
int token_idx = 0;
int image_idx = 0;
std::vector<int> img_boundaries, img_nums;
img_boundaries.emplace_back(0);
img_nums.emplace_back(0);
while (token_idx < seq_lens_origin) {
if (input_ids_ptr[token_idx] != image_patch_id) {
do {
token_idx++;
} while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id);
} else {
int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times;
image_idx++;
token_idx += cur_image_token_len;
}
img_boundaries.emplace_back(token_idx);
img_nums.emplace_back(image_idx);
}
int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());
for (int i = 0; i < num_img_boundaries; i++) {
out.data<int64_t>()[i] = img_boundaries[i];
out.data<int64_t>()[num_img_boundaries + i] = img_nums[i];
}
return {out};
}
PD_BUILD_OP(get_img_boundaries)
.Inputs({"task_input_ids", "grid_thw"})
.Attrs({"image_patch_id: int64_t"})
.Outputs({"img_boundaries"})
.SetKernelFn(PD_KERNEL(GetImgBoundaries));
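For intuition, the boundary scan implemented above can be mirrored in plain Python. This is an illustrative sketch only, not part of the commit; the token_times = 4 patch-merge factor and the [t, h, w] layout of grid_thw follow the C++ kernel:

def img_boundaries(input_ids, grid_thw, image_patch_id, token_times=4):
    """Return (boundaries, image_counts): prefix positions and #images seen so far."""
    boundaries, img_nums = [0], [0]
    token_idx, image_idx = 0, 0
    n = len(input_ids)
    while token_idx < n:
        if input_ids[token_idx] != image_patch_id:
            # skip over the current text span
            while token_idx < n and input_ids[token_idx] != image_patch_id:
                token_idx += 1
        else:
            # jump over one image: (h * w) patches merged by token_times
            h, w = grid_thw[image_idx][1], grid_thw[image_idx][2]
            token_idx += (h * w) // token_times
            image_idx += 1
        boundaries.append(token_idx)
        img_nums.append(image_idx)
    return boundaries, img_nums


# two text tokens, one 1x4x4 image (4 patch tokens after merging), two text tokens
ids = [101, 102, 9, 9, 9, 9, 103, 104]
print(img_boundaries(ids, [[1, 4, 4]], image_patch_id=9))
# ([0, 2, 6, 8], [0, 0, 1, 1])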


@@ -145,7 +145,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
             ? up_gate_proj_weight_scale.get_ptr()->data<float>()
             : nullptr),
         xftblock_tw,
-        std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
+        std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
+    );
     xdown_proj_w = std::make_shared<xftblock::Tensor>(
         const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
@@ -153,7 +154,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
             ? down_proj_weight_scale.get_ptr()->data<float>()
             : nullptr),
         xftblock_tw,
-        std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
+        std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
+    );
   }
   std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
   std::shared_ptr<xftblock::Tensor> xdown_proj_bias;


@@ -0,0 +1,83 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = input.dims()[0];
const int64_t hidden_size = input.dims()[1];
const int64_t text_token_num = text_input.dims()[0];
const int64_t image_token_num = image_input.dims()[0];
switch (input.type()) {
case paddle::DataType::BFLOAT16: {
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 data_t;
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<XPUType*>(input.data<data_t>()),
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
reinterpret_cast<int*>(token_type_ids.data<int>()),
reinterpret_cast<int*>(text_index.data<int>()),
reinterpret_cast<int*>(image_index.data<int>()),
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
break;
}
default: {
PD_THROW(
"NOT supported data type. Only support BFLOAT16. ");
break;
}
}
}
PD_BUILD_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
"image_input",
"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_input_out",
"image_input_out",
"text_index_out",
"image_index_out"})
.Attrs({"is_scatter:bool"})
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));


@@ -0,0 +1,48 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageIndexOut(
const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
!= paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
PD_THROW("NOT supported data type. Only support INT32. ");
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = token_type_ids.shape()[0];
int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
token_type_ids.data<int32_t>(),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
token_num);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
}
PD_BUILD_OP(text_image_index_out)
.Inputs({"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_index_out",
"image_index_out"})
.SetInplaceMap({{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageIndexOut));
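Taken together, the two ops registered above give each token a running position inside its own modality stream and then move rows between the fused sequence and the per-modality buffers. A NumPy reference of that behaviour, mirroring the CPU wrappers that appear later in this commit; it is illustrative only, not the device implementation:

import numpy as np


def index_out_ref(token_type_ids):
    """text_index[i] / image_index[i]: running position of token i inside its own stream."""
    text_index = np.zeros_like(token_type_ids)
    image_index = np.zeros_like(token_type_ids)
    text_count = image_count = 0
    for i, t in enumerate(token_type_ids):
        if t == 0:
            text_index[i] = text_count
            text_count += 1
        else:
            image_index[i] = image_count
            image_count += 1
    return text_index, image_index


def scatter_ref(input_states, token_type_ids, text_index, image_index, text_token_num, image_token_num):
    """is_scatter=True direction: copy row i of the fused input into its modality buffer."""
    hidden = input_states.shape[1]
    text_input = np.zeros((text_token_num, hidden), dtype=input_states.dtype)
    image_input = np.zeros((image_token_num, hidden), dtype=input_states.dtype)
    for i, t in enumerate(token_type_ids):
        if t == 0:
            text_input[text_index[i]] = input_states[i]
        else:
            image_input[image_index[i]] = input_states[i]
    return text_input, image_input


token_type_ids = np.array([0, 0, 1, 1, 0], dtype=np.int32)  # 0 = text, 1 = image
hidden_states = np.arange(10, dtype=np.float32).reshape(5, 2)
text_idx, image_idx = index_out_ref(token_type_ids)
print(text_idx)   # [0 1 0 0 2]
print(image_idx)  # [0 0 0 1 0]
print(scatter_ref(hidden_states, token_type_ids, text_idx, image_idx, 3, 2))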


@@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
                                   const TSCALE *scale_in, TY *y,
                                   TSCALE *scale_out, int64_t m, int64_t n);
+
+DLL_EXPORT int text_image_index_out(Context* ctx,
+                                    const int* token_type_ids, // x
+                                    int* text_index,           // y1
+                                    int* image_index,          // y2
+                                    const int64_t token_num);
+
+template <typename T>
+DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
+                                         T* input,
+                                         T* text_input,
+                                         T* image_input,
+                                         int* token_type_ids,
+                                         int* text_index,
+                                         int* image_index,
+                                         int64_t token_num,
+                                         int64_t text_token_num,
+                                         int64_t image_token_num,
+                                         int64_t hidden_size,
+                                         bool is_scatter);
+
 /*--------------------------------------- MTP being --------------------------------------------*/


@@ -0,0 +1,175 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
template <typename T>
static __device__ inline void text_image_gather(
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
}
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
}
}
}
template <typename T>
static __device__ inline void text_image_scatter(
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
}
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len);
}
}
}
template <typename T>
__global__ void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
if (is_scatter) {
text_image_scatter(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, input_lm);
} else {
text_image_gather(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, input_lm);
}
}
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
template __global__ void text_image_gather_scatter<T>( \
T* input, \
T* text_input, \
T* image_input, \
int* token_type_ids, \
int* text_index, \
int* image_index, \
int64_t token_num, \
int64_t text_token_num, \
int64_t image_token_num, \
int64_t hidden_size, \
bool is_scatter);
_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16);
} // namespace plugin
} // namespace xpu3


@@ -0,0 +1,97 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2025 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) {
for (int j = 0; j < size; j++) {
if (lm_x[j] == 0) {
lm_y1[j] = text_count;
text_count += 1;
} else {
lm_y2[j] = images_count;
images_count += 1;
}
}
mfence_lm_sm();
}
__global__ void text_image_index_out_kernel(
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
const int cid = core_id();
const int tid = core_id() * cluster_num() + cluster_id();
const int nthreads = core_num() * cluster_num();
if (tid >= 1) return;
constexpr int BUFSIZE = 1024;
constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
const int64_t len = token_num;
__simd__ char buffer0[BUFSIZE * 3];
__simd__ char buffer1[BUFSIZE * 3];
__simd__ __shared__ char buffer2[64][BUFSIZE * 2];
DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x((SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1((LmPtr<int>((int*)buffer0)));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2((LmPtr<int>((int*)buffer1)));
int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
int64_t i = tid * buflen;
int read_size = 0;
int offset = nthreads * buflen;
int text_count = 0;
int images_count = 0;
if (i < len) {
read_size = min<int64_t>(buflen, len - i);
buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
mfence();
}
while (i < len && i + offset < len) {
i = i + offset;
int read_size_next = min<int64_t>(buflen, len - i);
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
buffer_ptr_x.toggle();
buffer_ptr_y1.toggle();
buffer_ptr_y2.toggle();
read_size = read_size_next;
}
if (i < len) {
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
buffer_ptr_y1.gm_store_async(text_index + i, read_size);
buffer_ptr_y2.gm_store(image_index + i, read_size);
}
}
} // namespace plugin
} // namespace xpu3


@@ -0,0 +1,182 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(
Context* ctx,
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
if (is_scatter) {
// Scatter mode: input -> text_input/image_input
for (int64_t i = 0; i < token_num; i++) {
int token_type = token_type_ids[i];
T* text_image_input = nullptr;
int* text_image_index = nullptr;
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else { // token_type == 1
text_image_input = image_input;
text_image_index = image_index;
}
int text_image_token_idx = text_image_index[i];
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int64_t j = 0; j < hidden_size; j++) {
T value = input[input_offset + j];
text_image_input[text_image_offset + j] = value;
}
}
} else {
// Gather mode: text_input/image_input -> input
for (int64_t i = 0; i < token_num; i++) {
int token_type = token_type_ids[i];
T* text_image_input = nullptr;
int* text_image_index = nullptr;
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else { // token_type == 1
text_image_input = image_input;
text_image_index = image_index;
}
int text_image_token_idx = text_image_index[i];
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int64_t j = 0; j < hidden_size; j++) {
T value = text_image_input[text_image_offset + j];
input[input_offset + j] = value;
}
}
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper(
Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
xpu3::plugin::text_image_gather_scatter<T> <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
return api::SUCCESS;
}
template <typename T>
int text_image_gather_scatter(
Context* ctx,
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T);
WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index);
WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input);
if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size]
WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input);
}
if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size]
WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input);
}
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int text_image_gather_scatter(
Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu


@@ -0,0 +1,103 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
int text_count = 0;
int image_count = 0;
for (int64_t i = 0; i < token_num; ++i) {
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
++text_count;
} else {
image_index[i] = image_count;
++image_count;
}
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
token_type_ids,
text_index,
image_index,
token_num);
return api::SUCCESS;
}
int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int);
WRAPPER_DUMP_PARAM4(
ctx, token_type_ids, text_index, image_index, token_num);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, token_num, 0);
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
token_type_ids,
text_index,
image_index,
token_num);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
token_type_ids,
text_index,
image_index,
token_num);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu


@@ -30,6 +30,7 @@ import paddle
 from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger
@@ -157,6 +158,7 @@ class ResourceManagerV1(ResourceManager):
         # TODO: set condition to new _get_num_new_tokens
         num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
         num_new_tokens = min(num_new_tokens, token_budget)
+        request.with_image = False
         if not self.config.model_config.enable_mm:
             return num_new_tokens
@@ -219,6 +221,9 @@ class ResourceManagerV1(ResourceManager):
                     grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
                 grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
-                from fastdeploy.model_executor.ops.gpu import get_img_boundaries
+                if current_platform.is_xpu():
+                    from fastdeploy.model_executor.ops.xpu import get_img_boundaries
+                else:
+                    from fastdeploy.model_executor.ops.gpu import get_img_boundaries
                 request.multimodal_img_boundaries = get_img_boundaries(


@@ -232,6 +232,8 @@ class XPUForwardMeta(ForwardMeta):
     dec_batch: Optional[paddle.Tensor] = None
     #
     total_enc_len: Optional[paddle.Tensor] = None
+    # position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
+    pos_emb_type: Optional[str] = "NORMAL"

 @dataclass


@@ -183,5 +183,7 @@ class XPUAttentionBackend(AttentionBackend):
             forward_meta.encoder_batch_map_cpu,
             forward_meta.decoder_context_len_cpu,
             forward_meta.decoder_batch_map_cpu,
+            forward_meta.pos_emb_type,
+            self.rope_3d,
         )
         return res


@@ -72,7 +72,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
             layer.top_k,
             False,  # moe group, used in deepseek
         )
-        if layer.tp_size > 1:
+        if layer.reduce_results and layer.tp_size > 1:
             from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )
@@ -252,7 +252,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
             layer.top_k,
             False,  # moe group, used in deepseek
         )
-        if layer.tp_size > 1:
+        if layer.reduce_results and layer.tp_size > 1:
             from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )


@@ -31,6 +31,7 @@ from paddleformers.utils.log import logger
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.graph_optimization.decorator import (
     cuda_graph_buffers,
     support_graph_optimization,
@@ -44,20 +45,15 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
     Ernie4_5_Attention,
     Ernie4_5_MLP,
 )
+from fastdeploy.model_executor.models.ernie4_5_vl.image_op import (
+    text_image_gather_scatter,
+    text_image_index_out,
+)
 from fastdeploy.model_executor.models.model_base import (
     ModelCategory,
     ModelForCasualLM,
     ModelRegistry,
 )
-from fastdeploy.platforms import current_platform
-
-if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import (
-        text_image_gather_scatter,
-        text_image_index_out,
-    )
-from fastdeploy.model_executor.forward_meta import ForwardMeta


 class Ernie4_5_VLMLP(Ernie4_5_MLP):


@@ -0,0 +1,32 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
text_image_gather_scatter,
text_image_index_out,
)
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
text_image_gather_scatter,
text_image_index_out,
)
else:
raise ImportError("Unsupported platform, only support CUDA and XPU")
__all__ = ["text_image_gather_scatter", "text_image_index_out"]
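A hedged usage sketch of the round trip a VL forward pass performs with these ops, importing them through the dispatch module above. Shapes and dtypes follow the XPU kernels in this commit (BFLOAT16 hidden states, INT32 ids); the call sequence is illustrative rather than a copy of the model code and assumes the compiled ops are available on the current platform:

import paddle

from fastdeploy.model_executor.models.ernie4_5_vl.image_op import (
    text_image_gather_scatter,
    text_image_index_out,
)

token_type_ids = paddle.to_tensor([0, 1, 1, 0], dtype="int32")  # 0 = text, 1 = image
text_index = paddle.zeros_like(token_type_ids)
image_index = paddle.zeros_like(token_type_ids)
text_image_index_out(token_type_ids, text_index, image_index)  # fills both index tensors in place

hidden_size = 8
hidden_states = paddle.randn([4, hidden_size]).astype("bfloat16")
text_states = paddle.zeros([2, hidden_size], dtype="bfloat16")
image_states = paddle.zeros([2, hidden_size], dtype="bfloat16")

# is_scatter=True: split the fused sequence into per-modality buffers;
# is_scatter=False: gather the (separately processed) buffers back into hidden_states.
text_image_gather_scatter(
    hidden_states, text_states, image_states,
    token_type_ids, text_index, image_index, True,
)
# ... run the text / image branches on text_states and image_states here ...
text_image_gather_scatter(
    hidden_states, text_states, image_states,
    token_type_ids, text_index, image_index, False,
)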


@@ -25,6 +25,7 @@ from paddle import nn
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
+from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
 from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
 from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
@@ -34,10 +35,11 @@ from fastdeploy.model_executor.layers.attention import get_attention_backend
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend,
 )
-from fastdeploy.model_executor.layers.rotary_embedding import get_rope
+from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler
 from fastdeploy.model_executor.model_loader import get_model_loader
+from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.model_executor.ops.xpu import (
     adjust_batch,
     get_infer_param,
@@ -201,6 +203,45 @@ def xpu_post_process(
         update_inputs,
     )

+    # handle vl:
+    if model_output.enable_thinking:
+        exists_think_end = sampled_token_ids == model_output.think_end_id
+        paddle.assign(
+            paddle.where(
+                exists_think_end,
+                model_output.need_think_end - 1,
+                model_output.need_think_end,
+            ),
+            model_output.need_think_end,
+        )
+
+        paddle.assign(
+            paddle.where(
+                model_output.need_think_end.cast("bool"),
+                model_output.reasoning_index - 1,
+                model_output.reasoning_index,
+            ),
+            model_output.reasoning_index,
+        )
+
+        stop_wo_think = (
+            (sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
+            | (model_output.reasoning_index == 0)
+        ) & (model_output.need_think_end > 0)
+        sampled_token_ids = paddle.where(
+            stop_wo_think,
+            model_output.think_end_id,
+            sampled_token_ids,
+        )
+        paddle.assign(
+            paddle.where(
+                stop_wo_think,
+                model_output.need_think_end - 1,
+                model_output.need_think_end,
+            ),
+            model_output.need_think_end,
+        )
+
     # 1. Set stop value
     paddle.assign(
         paddle.where(
@@ -340,11 +381,36 @@ class XPUModelRunner(ModelRunnerBase):
     def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int):
         super().__init__(fd_config=fd_config, device=device)
+        self.enable_mm = self.model_config.enable_mm
         self.rank = rank
         self.local_rank = local_rank
+        self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
+
+        # VL model config:
+        if self.enable_mm:
+            self._init_image_preprocess()
+
+            self.amp_black = [
+                "reduce_sum",
+                "c_softmax_with_cross_entropy",
+                "elementwise_div",
+                "sin",
+                "cos",
+                "sort",
+                "multinomial",
+            ]
+            self.amp_white = [
+                "lookup_table",
+                "lookup_table_v2",
+                "flash_attn",
+                "matmul",
+                "matmul_v2",
+                "fused_gemm_epilogue",
+            ]

         # Sampler
-        self.sampler = Sampler()
+        # TODU(lilujia): sync with GPU
+        self.sampler = Sampler(fd_config)

         # Lazy initialize kv cache after model loading
         # self.kv_caches: list[paddle.Tensor] = []
@@ -364,18 +430,28 @@ class XPUModelRunner(ModelRunnerBase):
         ).cpu()

         # Initialize attention Backend
-        # Note(gonshaotian): Currently, all attention layers share one attention backend instance.
+        # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
         # In the future, we will expand it as a list.
         self.attn_backends: list[AttentionBackend] = []
         self.initialize_attn_backend()

         # Forward meta store the global meta information of the forward
         self.forward_meta: ForwardMeta = None

+    def exist_prefill(self):
+        """
+        check whether prefill stage exist
+        """
+        if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
+            return 1
+        else:
+            return 0
+
     def insert_tasks_v1(self, req_dicts: List[Request]):
         """
         Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
+        req_dict: A list of Request dict
+        num_running_requests: batch_size
         """
         # NOTE(luotingdan): Lazy initialize kv cache
         if "caches" not in self.share_inputs:
@@ -388,10 +464,53 @@ class XPUModelRunner(ModelRunnerBase):
             request = req_dicts[i]
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
+                if self.enable_mm:
+                    inputs = request.multimodal_inputs
+                    if request.with_image:
+                        vision_inputs = {}
+                        vision_inputs["input_ids"] = paddle.to_tensor(
+                            inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
+                        )
+                        vision_inputs["token_type_ids"] = paddle.to_tensor(
+                            inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
+                        )
+                        vision_inputs["image_type_ids"] = paddle.to_tensor(
+                            inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
+                            dtype=paddle.int64,
+                        )
+                        vision_inputs["images"] = paddle.to_tensor(
+                            inputs["images"][request.image_start : request.image_end], dtype="uint8"
+                        )
+                        vision_inputs["grid_thw"] = paddle.to_tensor(
+                            inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
+                        )
+                        self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
+                    else:
+                        self.share_inputs["image_features"] = None
+                    if inputs["position_ids"] is not None:
+                        position_ids = paddle.to_tensor(
+                            request.multimodal_inputs["position_ids"],
+                            dtype="int64",
+                        ).unsqueeze([0])
+                    else:
+                        position_ids = None
+                    enable_thinking = request.get("enable_thinking", True)
+                    enable_thinking = enable_thinking if enable_thinking is not None else True
+                    self.share_inputs["enable_thinking"][:] = enable_thinking
+                    self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
+                    self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
+                    self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
+                        position_ids, request.get("max_tokens", 2048)
+                    )
-                input_ids = request.prompt_token_ids + request.output_token_ids
+                if len(request.output_token_ids) == 0:
+                    input_ids = request.prompt_token_ids
+                else:
+                    input_ids = request.prompt_token_ids + request.output_token_ids
                 logger.debug(
                     f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
@@ -475,41 +594,86 @@ class XPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
+                    request.sampling_params.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
-                )
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
+                )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

         if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True

-    def process_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request]):
         """Process inputs for prefill tasks and update share_inputs buffer"""
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
-            length = request.prompt_token_ids_len
+            length = len(request.prompt_token_ids)
+            assert length > 0, "The prompt requested must not be empty."
+
+            self.share_inputs["pre_ids"][idx : idx + 1] = -1
+            self.share_inputs["step_idx"][idx : idx + 1] = 0
             self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
+            self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
+            if self.enable_mm:
+                inputs = self._preprocess_mm_task(request.multimodal_inputs)
+                if inputs.get("images") is not None:
+                    self.share_inputs["image_features"] = self.extract_vision_features(inputs)
+                else:
+                    # Compatible with the situation that lacks images and videos
+                    self.share_inputs["image_features"] = None
+                position_ids = inputs["position_ids"]
+                length = inputs["input_ids"].shape[1]
+                self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"]
+            else:
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+            self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
+            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
+            self.share_inputs["prompt_lens"][idx : idx + 1] = length
+            if self.enable_mm:
+                enable_thinking = request.get("enable_thinking", True)
+                enable_thinking = enable_thinking if enable_thinking is not None else True
+                self.share_inputs["enable_thinking"][:] = enable_thinking
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
+                self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
+                    position_ids, request.get("max_tokens", 2048)
+                )
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+
+            def get_attr_from_request(request, attr, default_value=None):
+                res = request.get(attr, default_value)
+                if res is not None:
+                    return res
+                else:
+                    return default_value

             assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            self.share_inputs["pre_ids"][idx : idx + 1] = -1
-            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
+            self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
             self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
             self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
             self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
             self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
-            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
-            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
-            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
-            self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
-            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
-            self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
-            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
-            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
-            self.share_inputs["step_idx"][idx : idx + 1] = 0
-            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
+            self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
+            self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
+                request, "repetition_penalty", 1.0
+            )
+            self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request(
+                request, "frequency_penalty", 0.0
+            )
+            self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
+                request, "presence_penalty", 0.0
+            )
+            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
             self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
                 "max_tokens", self.model_config.max_model_len
             )
@@ -540,11 +704,15 @@ class XPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
+                    request.sampling_params.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
-                )
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
+                )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

         self.share_inputs["not_need_stop"][0] = True
@@ -565,6 +733,11 @@ class XPUModelRunner(ModelRunnerBase):
             self.model_config.pad_token_id,
             dtype="int64",
         )
+        self.share_inputs["prompt_ids"] = paddle.full(
+            [max_num_seqs, self.parallel_config.max_model_len],
+            self.model_config.pad_token_id,
+            dtype="int64",
+        )
         self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
@@ -627,7 +800,9 @@ class XPUModelRunner(ModelRunnerBase):
         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
         # TODO(gongshaotian): move to models
+        if not self.enable_mm:
             self.share_inputs["rope_emb"] = get_rope(
                 rotary_dim=self.model_config.head_dim,
                 position_ids=tmp_position_ids,
@@ -654,18 +829,40 @@ class XPUModelRunner(ModelRunnerBase):
         self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")

         # Initialize stop seqs
-        self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
+        self.share_inputs["stop_seqs_len"] = paddle.full(
+            [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
+        )
         self.share_inputs["stop_seqs"] = paddle.full(
             [
+                max_num_seqs,
                 self.model_config.max_stop_seqs_num,
                 self.model_config.stop_seqs_max_len,
             ],
             -1,
-            dtype="int32",
+            dtype="int64",
         )
+        if self.enable_mm:
+            head_dim = self.model_config.head_dim
+            self.share_inputs["rope_emb"] = paddle.full(
+                shape=[
+                    max_num_seqs,
+                    2,
+                    1,
+                    self.parallel_config.max_model_len,
+                    1,
+                    head_dim // 2,
+                ],
+                fill_value=0,
+                dtype="float32",
+            )
+            self.share_inputs["image_features"] = None
+            self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+            self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
+            self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

     def _prepare_inputs(self, is_dummy_run=False) -> None:
-        """prepare the model inputs"""
+        """Prepare the model inputs"""
         if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run:
             recover_decode_task(
                 self.share_inputs["stop_flags"],
@@ -689,10 +886,13 @@ class XPUModelRunner(ModelRunnerBase):
         # Update bad tokens len
         max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])

+        if self.enable_mm:  # pos_emb_type is different in EB and VL
+            self.forward_meta.pos_emb_type = "HALF_HEAD_DIM"
         self.forward_meta.attn_backend = self.attn_backends[0]
         self.initialize_attention_backend()

         # Get sampling metadata
+        # TODU(lilujia): sync with GPU
         self.sampling_metadata = SamplingMetadata(
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
@@ -703,12 +903,16 @@ class XPUModelRunner(ModelRunnerBase):
             seed=self.share_inputs["infer_seed"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],
+            prompt_ids=self.share_inputs["prompt_ids"],
+            prompt_lens=self.share_inputs["prompt_lens"],
             frequency_penalties=self.share_inputs["frequency_score"],
             presence_penalties=self.share_inputs["presence_score"],
             repetition_penalties=self.share_inputs["penalty_score"],
             min_dec_lens=self.share_inputs["min_dec_len"],
             bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
+            enable_early_stop=self.enable_early_stop,
+            stop_flags=self.share_inputs["stop_flags"],
         )

     def load_model(self) -> None:
@@ -723,7 +927,7 @@ class XPUModelRunner(ModelRunnerBase):
# 3. Load drafter model(for speculative decoding) # 3. Load drafter model(for speculative decoding)
def get_model(self) -> nn.Layer: def get_model(self) -> nn.Layer:
"""get current model""" """Get current model"""
return self.model return self.model
def initialize_attention_backend(self): def initialize_attention_backend(self):
@@ -741,6 +945,7 @@ class XPUModelRunner(ModelRunnerBase):
cache_kvs = {} cache_kvs = {}
max_block_num = self.num_gpu_blocks max_block_num = self.num_gpu_blocks
# Get kv cache dtype
cache_type = self.parallel_config.dtype cache_type = self.parallel_config.dtype
kv_cache_quant_type = None kv_cache_quant_type = None
@@ -800,33 +1005,6 @@ class XPUModelRunner(ModelRunnerBase):
) )
self.attn_backends.append(attn_backend) self.attn_backends.append(attn_backend)
def capture_model(self) -> None:
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
"""
logger.warn("XPU not support cuda graph currently")
pass
@sot_warmup_guard(True)
def sot_warmup(self) -> None:
start_time = time.perf_counter()
for batch_size in self.sot_warmup_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=batch_size,
)
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def exist_prefill(self):
"""
check whether prefill stage exist
"""
if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int): def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
"""Set dummy prefill inputs to share_inputs""" """Set dummy prefill inputs to share_inputs"""
full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10) full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10)
@@ -838,7 +1016,7 @@ class XPUModelRunner(ModelRunnerBase):
for i in range(batch_size): for i in range(batch_size):
idx = i idx = i
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1) self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
@@ -897,6 +1075,24 @@ class XPUModelRunner(ModelRunnerBase):
else: else:
paddle.device.xpu.set_debug_level(debug_level) paddle.device.xpu.set_debug_level(debug_level)
def capture_model(self) -> None:
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
"""
logger.warn("XPU not support cuda graph currently")
pass
@sot_warmup_guard(True)
def sot_warmup(self) -> None:
start_time = time.perf_counter()
for batch_size in self.sot_warmup_sizes:
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
batch_size=batch_size,
)
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def execute_model( def execute_model(
self, self,
model_forward_batch: Optional[List[Request]] = None, model_forward_batch: Optional[List[Request]] = None,
@@ -921,13 +1117,20 @@ class XPUModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda graph # 2. Padding inputs for cuda graph
# 3. Execute model # 3. Execute model
model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta) if self.enable_mm:
model_output = self.model(
self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
)
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
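# VL models additionally consume the precomputed image features; the text-only
# path keeps the original keyword-argument call.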
hiddden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta) hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
# 4. Compute logits, Sample # 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states) logits = self.model.compute_logits(hidden_states)
sampler_output = self.sampler(logits, self.sampling_metadata) sampler_output = self.sampler(logits, self.sampling_metadata)
# 5. Speculative decode # 5. Speculative decode
@@ -947,15 +1150,21 @@ class XPUModelRunner(ModelRunnerBase):
seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"], is_block_step=self.share_inputs["is_block_step"],
# speculative decoding
full_hidden_states=None,
msg_queue_id=self.parallel_config.msg_queue_id, msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank, mp_rank=self.local_rank,
use_ep=self.parallel_config.use_ep, use_ep=self.parallel_config.use_ep,
# speculative decoding
full_hidden_states=None,
draft_tokens=None, draft_tokens=None,
actual_draft_token_num=None, actual_draft_token_num=None,
accept_tokens=None, accept_tokens=None,
accept_num=None, accept_num=None,
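# Reasoning-mode ("thinking") bookkeeping and stop-sequence tensors; only the VL
# path (enable_mm) populates the thinking-related fields, text-only runs pass None / -1.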
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
) )
xpu_post_process( xpu_post_process(
sampled_token_ids=sampler_output.sampled_token_ids, sampled_token_ids=sampler_output.sampled_token_ids,
@@ -984,13 +1193,43 @@ class XPUModelRunner(ModelRunnerBase):
@profile_run_guard(True) @profile_run_guard(True)
def profile_run(self) -> None: def profile_run(self) -> None:
"""Execute a forward pass with dummy inputs to profile the memory usage of the model.""" """Execute a forward pass with dummy inputs to profile the memory usage of the model"""
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
self._dummy_run( self._dummy_run(
num_tokens=int(self.scheduler_config.max_num_batched_tokens), num_tokens=int(self.scheduler_config.max_num_batched_tokens),
batch_size=min(self.scheduler_config.max_num_seqs, 1), batch_size=min(self.scheduler_config.max_num_seqs, 1),
) )
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
"""
Set a globally unified block number and update the model's shared input.
Args:
num_gpu_blocks:
"""
self.num_gpu_blocks = num_gpu_blocks
# Reset block table and kv cache with global block num
self.initialize_kv_cache()
# Reset free list
free_list = list(
range(
self.num_gpu_blocks - 1,
int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
-1,
)
)
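# Illustrative numbers (hypothetical): num_gpu_blocks=1000, kv_cache_ratio=0.75
# -> free_list = [999, 998, ..., 750] (250 block ids, the top 25%), free_list_len = 250.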
self.free_list_len = len(free_list)
self.share_inputs.update(
{
"free_list": paddle.to_tensor(free_list, dtype="int32"),
"free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
}
)
def clear_block_table(self) -> None: def clear_block_table(self) -> None:
""" """
Clear the block tables and kv cache after profiling. Clear the block tables and kv cache after profiling.
@@ -1025,41 +1264,135 @@ class XPUModelRunner(ModelRunnerBase):
byte_of_dtype = 2 byte_of_dtype = 2
hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
required_memory = ( num_layers = self.model_config.num_hidden_layers
byte_of_dtype required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v
* 2 # k + v
* (self.cache_config.block_size * hidden_dim)
* self.model_config.num_hidden_layers
)
return required_memory return required_memory
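# Hedged back-of-the-envelope example (purely illustrative values):
#   byte_of_dtype=2 (fp16/bf16), block_size=64, head_dim=128, kv_num_heads=8, num_hidden_layers=48
#   hidden_dim = 128 * 8 = 1024
#   required_memory = 2 * 2 * (64 * 1024) * 48 = 12,582,912 bytes = 12 MiB per block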
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
"""
Set a globally unified block number and update the model's shared input.
Args:
num_gpu_blocks:
"""
self.num_gpu_blocks = num_gpu_blocks
# Reset block table and kv cache with global block num
self.initialize_kv_cache()
# Reset free list
free_list = list(
range(
self.num_gpu_blocks - 1,
int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
-1,
)
)
self.free_list_len = len(free_list)
self.share_inputs.update(
{
"free_list": paddle.to_tensor(free_list, dtype="int32"),
"free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
}
)
def not_need_stop(self) -> bool: def not_need_stop(self) -> bool:
""" """ """Stop decoding if the tensor meets the termination condition"""
return self.share_inputs["not_need_stop"][0] return self.share_inputs["not_need_stop"][0]
def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata"""
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.model_config.model,
image_preprocessor_name=str(self.model_config.model),
)
processor.eval()
image_preprocess = processor.image_preprocessor
image_preprocess.image_mean_tensor = paddle.to_tensor(image_preprocess.image_mean, dtype="float32").reshape(
[1, 3, 1, 1]
)
image_preprocess.image_std_tensor = paddle.to_tensor(image_preprocess.image_std, dtype="float32").reshape(
[1, 3, 1, 1]
)
image_preprocess.rescale_factor = paddle.to_tensor(image_preprocess.rescale_factor, dtype="float32")
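# Presumably the pixel inputs arrive already flattened into patch vectors
# (patch_size**2 values per channel), so the per-channel mean/std are expanded
# to the same length before normalization in extract_vision_features.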
image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze([-2, -1]).repeat_interleave(
self.model_config.vision_config.patch_size**2 * 1, -1
)
image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze([-2, -1]).repeat_interleave(
self.model_config.vision_config.patch_size**2 * 1, -1
)
self.image_preprocess = image_preprocess
def _preprocess_mm_task(self, one: dict) -> dict:
"""Convert one multimodal request dict (input_ids, images, grid_thw, ...) into paddle tensors"""
input_ids = one["input_ids"][np.newaxis, :]
input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
token_type_ids = one["token_type_ids"][np.newaxis, :]
token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if one["images"] is not None:
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
images = paddle.to_tensor(images, dtype="uint8")
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
else:
image_type_ids = None
images = None
grid_thw = None
if one["position_ids"] is not None:
position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
else:
position_ids = None
result = dict(
input_ids=input_ids,
image_type_ids=image_type_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
grid_thw=grid_thw,
images=images,
)
return result
@paddle.no_grad()
def extract_vision_features(self, inputs: dict) -> paddle.Tensor:
"""Run the vision encoder and resampler to produce image features for the input patches"""
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"].cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
images = images.cast("bfloat16")
token_type_ids = inputs["token_type_ids"]
token_type_ids_w_video = token_type_ids
input_ids = inputs["input_ids"]
# convert to img patch id
# TODO(lulinjun): may need to check model_config and model_cfg
image_mask = input_ids == self.model_config.im_patch_id
image_type_ids = inputs["image_type_ids"]
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.parallel_config.dtype,
):
image_features = self.model.vision_model.extract_feature(images, grid_thw)
if self.parallel_config.tensor_parallel_size > 1:
S, C = image_features.shape
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
image_features = ScatterOp.apply(image_features, axis=-1)  # split features across tensor-parallel (mp) ranks
image_features = image_features.reshape([S, -1])
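# When tensor_parallel_size > 1, each rank now presumably holds a 1/tp_size slice of the
# grouped (spatial_conv_size**2) patch channels, so the resampler below consumes sharded features.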
image_features = self.model.resampler_model(
image_features,
image_mask,
token_type_ids_w_video,
image_type_ids,
grid_thw,
)
return image_features
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
rope_emb = get_rope_3d(
position_ids=position_ids_3d_real,
rotary_dim=self.model_config.head_dim,
partial_rotary_factor=1.0,
base=self.model_config.rope_theta,
max_position=self.parallel_config.max_model_len,
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
model_type=self.model_config.model_type,
)
return rope_emb
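# Hedged illustration of the 3D position ids built above (hypothetical prefix of 4 tokens,
# max_len=2, m = max(position_ids) + 1; the three axes are presumably time/height/width):
#   position_ids         : [[[t0,h0,w0], [t1,h1,w1], [t2,h2,w2], [t3,h3,w3]]]
#   dec_pos_ids          : [[[m, m, m], [m+1, m+1, m+1]]]
#   position_ids_3d_real : concat along axis=1 -> shape [1, 6, 3], fed to get_rope_3d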

View File

@@ -51,12 +51,13 @@ class XpuWorker(WorkerBase):
"""Initialize device and Construct model runner""" """Initialize device and Construct model runner"""
if paddle.is_compiled_with_xpu(): if paddle.is_compiled_with_xpu():
# Set environment variable # Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"xpu:{self.local_rank}" self.device = f"xpu:{self.local_rank}"
paddle.device.set_device(self.device) paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype) paddle.set_default_dtype(self.parallel_config.dtype)
self.device_ids = self.parallel_config.device_ids.split(",")
gc.collect() gc.collect()
paddle.device.xpu.empty_cache()
else: else:
raise RuntimeError(f"Not support device type: {self.device_config.device}") raise RuntimeError(f"Not support device type: {self.device_config.device}")
@@ -69,12 +70,11 @@ class XpuWorker(WorkerBase):
local_rank=self.local_rank, local_rank=self.local_rank,
) )
def graph_optimize_and_warm_up_model(self) -> None: def exist_prefill(self):
""" """
Perform the warm-up and the graph optimization check whether a prefill stage exists
""" """
if self.model_runner.graph_opt_level >= 1: return self.model_runner.exist_prefill()
self.model_runner.sot_warmup()
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
""" """
@@ -133,20 +133,17 @@ class XpuWorker(WorkerBase):
paddle.device.xpu.empty_cache() paddle.device.xpu.empty_cache()
return available_kv_cache_memory # approximate value return available_kv_cache_memory # approximate value
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()
def load_model(self) -> None: def load_model(self) -> None:
""" """ """Load model"""
self.model_runner.load_model() self.model_runner.load_model()
def get_model(self) -> nn.Layer: def get_model(self) -> nn.Layer:
""" """ """Get current model"""
return self.model_runner.get_model() return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int) -> None: def initialize_cache(self, num_gpu_blocks: int) -> None:
""" """ """Initizlize the KV Cache with accurate num_gpu_blocks"""
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def execute_model( def execute_model(
@@ -158,12 +155,6 @@ class XpuWorker(WorkerBase):
""" """ """ """
return self.model_runner.execute_model(model_forward_batch, num_running_requests, is_dummy_run) return self.model_runner.execute_model(model_forward_batch, num_running_requests, is_dummy_run)
def exist_prefill(self):
"""
check whether prefill stage exist
"""
return self.model_runner.exist_prefill()
def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int = -1) -> None: def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int = -1) -> None:
"""Process new requests and then start the decode loop """Process new requests and then start the decode loop
TODO(gongshaotian):The scheduler should schedule the handling of prefill, TODO(gongshaotian):The scheduler should schedule the handling of prefill,
@@ -172,8 +163,19 @@ class XpuWorker(WorkerBase):
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.model_runner.insert_tasks_v1(req_dicts=req_dicts) self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
else: else:
self.model_runner.process_prefill_inputs(req_dicts=req_dicts) self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
def check_health(self) -> bool: def check_health(self) -> bool:
""" """ """ """
return True return True
def cal_theortical_kvcache(self) -> int:
"""Calculate the block memory required"""
return self.model_runner.cal_theortical_kvcache()