[XPU] support XPU VL model inference (#4030)

* [XPU] support XPU VL model inference

* fix image op import and device check

* rebase develop

* fix perf
Lucas
2025-09-25 14:34:15 +08:00
committed by GitHub
parent e36eccfdad
commit 87179cb744
18 changed files with 1300 additions and 146 deletions

View File

@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
const paddle::Tensor &encoder_seq_lod_cpu,
const paddle::Tensor &encoder_batch_map_cpu,
const paddle::Tensor &decoder_context_len_cpu,
const paddle::Tensor &decoder_batch_map_cpu) {
const paddle::Tensor &decoder_batch_map_cpu,
const std::string &pos_emb_type="NORMAL",
bool rope_3d=false) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
int rope_max_seqlen = 0;
int rope_3d_num_seqs = 1;
if (rope_3d) {
rope_max_seqlen = rotary_embs.dims()[3];
rope_3d_num_seqs = rotary_embs.dims()[0];
} else {
rope_max_seqlen = rotary_embs.dims()[2];
}
auto block_attn_out =
paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
prefix_lens_vp, // start_tokens
param.batch_size, // batch_size
1, // emb_batch_size
rotary_embs.dims()[2], // max_seqlen
rope_max_seqlen, // max_seqlen
param.head_num, param.kv_head_num, param.head_dim,
param.max_batch_size, block_size, max_block_per_seq, "BLHD",
"HLD", "NORMAL",
"HLD", pos_emb_type,
!p_kcache_perhead_scale.defined()
? nullptr
: p_kcache_perhead_scale.data<float>() +
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
vsl.slot_mapping_vp, // real_batch
param.batch_size, // batch_size
1, // emb_batch_size
rotary_embs.dims()[2], // max_seqlen TODO!!double check
rope_max_seqlen, // max_seqlen
param.head_num, param.kv_head_num, param.head_dim,
param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
"NORMAL",
pos_emb_type,
!p_kcache_perhead_scale.defined()
? nullptr
: p_kcache_perhead_scale.data<float>() +
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
param.kv_head_num, // v_cache_scale_inv
nullptr, // k_cache_zp
nullptr, // v_cache_zp
false); // b_c8_pc
false, // b_c8_pc
rope_3d, // rope_3d
rope_3d_num_seqs);
XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);
// attn decode
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
"decoder_context_len_cpu",
"decoder_batch_map_cpu",
})
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
.Outputs({"block_attn_out"})
.SetKernelFn(PD_KERNEL(BlockAttnKernel))
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
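A minimal sketch of the new rope_3d branch above. It assumes the text-only rotary layout keeps max_seqlen at dims()[2] and reuses the VL rope_emb shape allocated later in this change ([max_num_seqs, 2, 1, max_model_len, 1, head_dim // 2]); shapes and names are illustrative, not the op's actual inputs.

```python
import paddle

head_dim, max_model_len, max_num_seqs = 128, 64, 4

# assumed text-only layout: max_seqlen sits at index 2
rope_text = paddle.zeros([2, 1, max_model_len, head_dim])
# VL layout matching the rope_emb buffer allocated in the XPU model runner below
rope_vl = paddle.zeros([max_num_seqs, 2, 1, max_model_len, 1, head_dim // 2])

def rope_meta(rotary_embs, rope_3d):
    # mirrors the C++ branch above: returns (rope_max_seqlen, rope_3d_num_seqs)
    if rope_3d:
        return rotary_embs.shape[3], rotary_embs.shape[0]
    return rotary_embs.shape[2], 1

print(rope_meta(rope_text, False))  # (64, 1)
print(rope_meta(rope_vl, True))     # (64, 4)
```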

View File

@@ -0,0 +1,60 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
const paddle::Tensor& grid_thw,
const int64_t image_patch_id) {
// All tensors are on CPU
auto input_ids_ptr = task_input_ids.data<int64_t>();
int64_t seq_lens_origin = task_input_ids.numel();
auto grid_thw_ptr = grid_thw.data<int64_t>();
int token_times = 4;
int token_idx = 0;
int image_idx = 0;
std::vector<int> img_boundaries, img_nums;
img_boundaries.emplace_back(0);
img_nums.emplace_back(0);
while (token_idx < seq_lens_origin) {
if (input_ids_ptr[token_idx] != image_patch_id) {
do {
token_idx++;
} while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id);
} else {
int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times;
image_idx++;
token_idx += cur_image_token_len;
}
img_boundaries.emplace_back(token_idx);
img_nums.emplace_back(image_idx);
}
int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());
for (int i = 0; i < num_img_boundaries; i++) {
out.data<int64_t>()[i] = img_boundaries[i];
out.data<int64_t>()[num_img_boundaries + i] = img_nums[i];
}
return {out};
}
PD_BUILD_OP(get_img_boundaries)
.Inputs({"task_input_ids", "grid_thw"})
.Attrs({"image_patch_id: int64_t"})
.Outputs({"img_boundaries"})
.SetKernelFn(PD_KERNEL(GetImgBoundaries));
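For reference, a pure-Python mirror of the boundary scan above with a made-up input; it shows the [2, num_boundaries] layout of the output (row 0: token boundaries, row 1: cumulative image counts). token_times = 4 follows the constant in the kernel.

```python
def img_boundaries_ref(input_ids, grid_thw, image_patch_id, token_times=4):
    boundaries, nums = [0], [0]
    token_idx, image_idx, n = 0, 0, len(input_ids)
    while token_idx < n:
        if input_ids[token_idx] != image_patch_id:
            # skip a run of text tokens
            while token_idx < n and input_ids[token_idx] != image_patch_id:
                token_idx += 1
        else:
            # consume one image: h * w patches merged token_times-to-one
            _, h, w = grid_thw[image_idx]
            token_idx += (h * w) // token_times
            image_idx += 1
        boundaries.append(token_idx)
        nums.append(image_idx)
    return boundaries, nums

# 3 text tokens, one image of (4 * 8) // 4 = 8 patch tokens, then 2 text tokens
ids = [5, 5, 5] + [101] * 8 + [5, 5]
print(img_boundaries_ref(ids, [[1, 4, 8]], image_patch_id=101))
# ([0, 3, 11, 13], [0, 0, 1, 1])
```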

View File

@@ -145,7 +145,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
? up_gate_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
);
xdown_proj_w = std::make_shared<xftblock::Tensor>(
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
@@ -153,7 +154,8 @@ std::vector<paddle::Tensor> MoeLayerKernel(
? down_proj_weight_scale.get_ptr()->data<float>()
: nullptr),
xftblock_tw,
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
);
}
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <xft/xdnn_plugin.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = input.dims()[0];
const int64_t hidden_size = input.dims()[1];
const int64_t text_token_num = text_input.dims()[0];
const int64_t image_token_num = image_input.dims()[0];
switch (input.type()) {
case paddle::DataType::BFLOAT16: {
using XPUType = typename XPUTypeTrait<bfloat16>::Type;
typedef paddle::bfloat16 data_t;
int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<XPUType*>(input.data<data_t>()),
reinterpret_cast<XPUType*>(text_input.data<data_t>()),
reinterpret_cast<XPUType*>(image_input.data<data_t>()),
reinterpret_cast<int*>(token_type_ids.data<int>()),
reinterpret_cast<int*>(text_index.data<int>()),
reinterpret_cast<int*>(image_index.data<int>()),
token_num,
text_token_num,
image_token_num,
hidden_size,
is_scatter
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
break;
}
default: {
PD_THROW(
"NOT supported data type. Only support BFLOAT16. ");
break;
}
}
}
PD_BUILD_OP(text_image_gather_scatter)
.Inputs({"input",
"text_input",
"image_input",
"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_input_out",
"image_input_out",
"text_index_out",
"image_index_out"})
.Attrs({"is_scatter:bool"})
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));
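A hedged usage sketch of the op (toy shapes, assuming an XPU build of fastdeploy and indices already produced by text_image_index_out below; the exact call site lives in the VL model code). is_scatter=True splits the fused hidden states into per-modality buffers in place; is_scatter=False merges them back.

```python
import paddle
from fastdeploy.model_executor.ops.xpu import text_image_gather_scatter

token_num, hidden = 6, 8
hidden_states = paddle.randn([token_num, hidden]).cast("bfloat16")
token_type_ids = paddle.to_tensor([0, 0, 1, 1, 0, 1], dtype="int32")  # 0 = text, 1 = image
text_index = paddle.to_tensor([0, 1, 0, 0, 2, 0], dtype="int32")      # from text_image_index_out
image_index = paddle.to_tensor([0, 0, 0, 1, 0, 2], dtype="int32")
text_states = paddle.zeros([3, hidden]).cast("bfloat16")
image_states = paddle.zeros([3, hidden]).cast("bfloat16")

# scatter: hidden_states -> text_states / image_states (updated in place)
text_image_gather_scatter(
    hidden_states, text_states, image_states,
    token_type_ids, text_index, image_index, True,
)
# gather: text_states / image_states -> hidden_states
text_image_gather_scatter(
    hidden_states, text_states, image_states,
    token_type_ids, text_index, image_index, False,
)
```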

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
void TextImageIndexOut(
const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
const paddle::Tensor& image_index) {
if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
!= paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int64_t token_num = token_type_ids.shape()[0];
int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
token_type_ids.data<int32_t>(),
const_cast<int32_t*>(text_index.data<int32_t>()),
const_cast<int32_t*>(image_index.data<int32_t>()),
token_num);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
}
PD_BUILD_OP(text_image_index_out)
.Inputs({"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_index_out",
"image_index_out"})
.SetInplaceMap({{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageIndexOut));
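A tiny reference of the index computation with a made-up input: every text token receives its rank among text tokens and every image token its rank among image tokens, which is exactly what text_image_gather_scatter consumes.

```python
token_type_ids = [0, 0, 1, 0, 1, 1]
text_index, image_index = [0] * 6, [0] * 6
text_count = image_count = 0
for k, ty in enumerate(token_type_ids):
    if ty == 0:
        text_index[k] = text_count
        text_count += 1
    else:
        image_index[k] = image_count
        image_count += 1
print(text_index)   # [0, 1, 0, 2, 0, 0]  (only positions with ty == 0 are meaningful)
print(image_index)  # [0, 0, 0, 0, 1, 2]  (only positions with ty == 1 are meaningful)
```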

View File

@@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
const TSCALE *scale_in, TY *y,
TSCALE *scale_out, int64_t m, int64_t n);
DLL_EXPORT int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
template <typename T>
DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
/*--------------------------------------- MTP begin --------------------------------------------*/

View File

@@ -0,0 +1,175 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
template <typename T>
static __device__ inline void text_image_gather(
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
}
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(text_image_input + text_image_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm, input + input_offset + j, sizeof(T) * read_len);
}
}
}
template <typename T>
static __device__ inline void text_image_scatter(
__global_ptr__ T* input,
__global_ptr__ T* text_input,
__global_ptr__ T* image_input,
__global_ptr__ int* token_type_ids,
__global_ptr__ int* text_index,
__global_ptr__ int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
T* input_lm) {
int cid = core_id();
int clusterid = cluster_id();
int token_start_cluster;
int token_end_cluster;
int token_start_core;
int token_end_core;
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
// cluster partition
partition(cluster_id(), cluster_num(), (int)token_num, 1, &token_start_cluster, &token_end_cluster);
if (token_start_cluster >= token_end_cluster) {
return;
}
int rows_cluster = token_end_cluster - token_start_cluster; // total rows for a cluster
// core partition
partition(core_id(), core_num(), rows_cluster, 1, &token_start_core, &token_end_core);
int rows_core = token_end_core - token_start_core; // total rows for a core
token_start_core += token_start_cluster;
token_end_core += token_start_cluster;
int read_len;
for (int i = token_start_core; i < token_end_core; i += 1) {
int token_type, text_image_token_idx;
__global_ptr__ T* text_image_input = nullptr;
__global_ptr__ int* text_image_index = nullptr;
GM2LM(token_type_ids + i, &token_type, sizeof(int));
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else {
text_image_input = image_input;
text_image_index = image_index;
}
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int j = 0; j < hidden_size; j += BUFSIZE) {
read_len = min(hidden_size - j, BUFSIZE);
GM2LM(input + input_offset + j, input_lm, sizeof(T) * read_len);
LM2GM(input_lm, text_image_input + text_image_offset + j, sizeof(T) * read_len);
}
}
}
template <typename T>
__global__ void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
const int BUFSIZE = 2 * 1024 / sizeof(T); // 1024 for bf16, 512 for fp32
__simd__ T input_lm[BUFSIZE]; // 2KB for bf16 and fp32
if (is_scatter) {
text_image_scatter(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, input_lm);
} else {
text_image_gather(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, input_lm);
}
}
#define _XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(T) \
template __global__ void text_image_gather_scatter<T>( \
T* input, \
T* text_input, \
T* image_input, \
int* token_type_ids, \
int* text_index, \
int* image_index, \
int64_t token_num, \
int64_t text_token_num, \
int64_t image_token_num, \
int64_t hidden_size, \
bool is_scatter);
_XPU_DEF_TEXT_IMAGE_GATHER_SCATTER(bfloat16);
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,97 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2025 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/cluster_primitive_template.h"
namespace xpu3 {
namespace plugin {
static __device__ void do_calc(const _shared_ptr_ int* lm_x, int* lm_y1, int* lm_y2, int64_t size, int& text_count, int& images_count) {
for (int j = 0; j < size; j++) {
if (lm_x[j] == 0) {
lm_y1[j] = text_count;
text_count += 1;
} else {
lm_y2[j] = images_count;
images_count += 1;
}
}
mfence_lm_sm();
}
__global__ void text_image_index_out_kernel(
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
const int cid = core_id();
const int tid = core_id() * cluster_num() + cluster_id();
const int nthreads = core_num() * cluster_num();
if (tid >= 1) return;
constexpr int BUFSIZE = 1024;
constexpr int READ_MAX_SIZE = BUFSIZE / sizeof(int);
const int64_t len = token_num;
__simd__ char buffer0[BUFSIZE * 3];
__simd__ char buffer1[BUFSIZE * 3];
__simd__ __shared__ char buffer2[64][BUFSIZE * 2];
DoublePtr<READ_MAX_SIZE, SmPtr<int>> buffer_ptr_x((SmPtr<int>((_shared_ptr_ int*)buffer2[cid])));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y1((LmPtr<int>((int*)buffer0)));
TriplePtr<READ_MAX_SIZE, LmPtr<int>> buffer_ptr_y2((LmPtr<int>((int*)buffer1)));
int64_t buflen = get_1d_buflen(len, nthreads, READ_MAX_SIZE, 64);
int64_t i = tid * buflen;
int read_size = 0;
int offset = nthreads * buflen;
int text_count = 0;
int images_count = 0;
if (i < len) {
read_size = min<int64_t>(buflen, len - i);
buffer_ptr_y1.gm_load_async(text_index + tid * buflen, read_size);
buffer_ptr_y2.gm_load_async(image_index + tid * buflen, read_size);
buffer_ptr_x.gm_load_async(token_type_ids + tid * buflen, read_size);
mfence();
}
while (i < len && i + offset < len) {
i = i + offset;
int read_size_next = min<int64_t>(buflen, len - i);
buffer_ptr_x.next().gm_load_async(token_type_ids + i, read_size_next);
buffer_ptr_y1.next().gm_load_async(text_index + i, read_size_next);
buffer_ptr_y2.next().gm_load_async(image_index + i, read_size_next);
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
buffer_ptr_y1.gm_store_async(text_index + i - offset, read_size);
buffer_ptr_y2.gm_store_async(image_index + i - offset, read_size);
buffer_ptr_x.toggle();
buffer_ptr_y1.toggle();
buffer_ptr_y2.toggle();
read_size = read_size_next;
}
if (i < len) {
do_calc(buffer_ptr_x.ptr, buffer_ptr_y1.ptr, buffer_ptr_y2.ptr, read_size, text_count, images_count);
buffer_ptr_y1.gm_store_async(text_index + i, read_size);
buffer_ptr_y2.gm_store(image_index + i, read_size);
}
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,182 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
template <typename T>
__attribute__((global)) void text_image_gather_scatter(
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(
Context* ctx,
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
if (is_scatter) {
// Scatter mode: input -> text_input/image_input
for (int64_t i = 0; i < token_num; i++) {
int token_type = token_type_ids[i];
T* text_image_input = nullptr;
int* text_image_index = nullptr;
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else { // token_type == 1
text_image_input = image_input;
text_image_index = image_index;
}
int text_image_token_idx = text_image_index[i];
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int64_t j = 0; j < hidden_size; j++) {
T value = input[input_offset + j];
text_image_input[text_image_offset + j] = value;
}
}
} else {
// Gather mode: text_input/image_input -> input
for (int64_t i = 0; i < token_num; i++) {
int token_type = token_type_ids[i];
T* text_image_input = nullptr;
int* text_image_index = nullptr;
if (token_type == 0) {
text_image_input = text_input;
text_image_index = text_index;
} else { // token_type == 1
text_image_input = image_input;
text_image_index = image_index;
}
int text_image_token_idx = text_image_index[i];
int input_offset = i * hidden_size;
int text_image_offset = text_image_token_idx * hidden_size;
for (int64_t j = 0; j < hidden_size; j++) {
T value = text_image_input[text_image_offset + j];
input[input_offset + j] = value;
}
}
}
return api::SUCCESS;
}
template <typename T>
static int xpu3_wrapper(
Context* ctx,
T* input,
T* text_input,
T* image_input,
int* token_type_ids,
int* text_index,
int* image_index,
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
xpu3::plugin::text_image_gather_scatter<T> <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
return api::SUCCESS;
}
template <typename T>
int text_image_gather_scatter(
Context* ctx,
T* input, // shape [token_num, hidden_size]
T* text_input, // shape [text_token_num, hidden_size]
T* image_input, // shape [image_token_num, hidden_size]
int* token_type_ids,// shape [token_num], 0 for text, 1 for image
int* text_index, // shape [token_num], mapping from input to text_input
int* image_index, // shape [token_num], mapping from input to image_input
int64_t token_num,
int64_t text_token_num,
int64_t image_token_num,
int64_t hidden_size,
bool is_scatter) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_gather_scatter", T);
WRAPPER_DUMP_PARAM6(ctx, input, text_input, image_input, token_type_ids, text_index, image_index);
WRAPPER_DUMP_PARAM5(ctx, token_num, text_token_num, image_token_num, hidden_size, is_scatter);
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_PTR(ctx, T, token_num * hidden_size, input);
if (text_token_num != 0) { // avoiding text_input tensor with shape [0, hidden_size]
WRAPPER_CHECK_PTR(ctx, T, text_token_num * hidden_size, text_input);
}
if (image_token_num != 0) { // avoiding image_input tensor with shape [0, hidden_size]
WRAPPER_CHECK_PTR(ctx, T, image_token_num * hidden_size, image_input);
}
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
WRAPPER_ASSERT_EQ(ctx, token_num, text_token_num + image_token_num);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
}
if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper<T>(
ctx, input, text_input, image_input, token_type_ids, text_index, image_index,
token_num, text_token_num, image_token_num, hidden_size, is_scatter
);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int text_image_gather_scatter(
Context*, bfloat16*, bfloat16*, bfloat16*, int*, int*, int*, const int64_t, const int64_t, const int64_t, const int64_t, bool);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -0,0 +1,103 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
__attribute__((global)) void text_image_index_out_kernel(const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int cpu_wrapper(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
int text_count = 0;
int image_count = 0;
for (int64_t i = 0; i < token_num; ++i) {
if (token_type_ids[i] == 0) {
text_index[i] = text_count;
++text_count;
} else {
image_index[i] = image_count;
++image_count;
}
}
return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
xpu3::plugin::text_image_index_out_kernel<<<1, 1, ctx->xpu_stream>>>(
token_type_ids,
text_index,
image_index,
token_num);
return api::SUCCESS;
}
int text_image_index_out(Context* ctx,
const int* token_type_ids, // x
int* text_index, // y1
int* image_index, // y2
const int64_t token_num) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "text_image_index_out", int);
WRAPPER_DUMP_PARAM4(
ctx, token_type_ids, text_index, image_index, token_num);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, token_num, 0);
WRAPPER_CHECK_PTR(ctx, int, token_num, token_type_ids);
WRAPPER_CHECK_PTR(ctx, int, token_num, text_index);
WRAPPER_CHECK_PTR(ctx, int, token_num, image_index);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper(ctx,
token_type_ids,
text_index,
image_index,
token_num);
} else if (ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
token_type_ids,
text_index,
image_index,
token_num);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -30,6 +30,7 @@ import paddle
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger
@@ -157,6 +158,7 @@ class ResourceManagerV1(ResourceManager):
# TODO: set condition to new _get_num_new_tokens
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
num_new_tokens = min(num_new_tokens, token_budget)
request.with_image = False
if not self.config.model_config.enable_mm:
return num_new_tokens
@@ -219,6 +221,9 @@ class ResourceManagerV1(ResourceManager):
grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import get_img_boundaries
else:
from fastdeploy.model_executor.ops.gpu import get_img_boundaries
request.multimodal_img_boundaries = get_img_boundaries(

View File

@@ -232,6 +232,8 @@ class XPUForwardMeta(ForwardMeta):
dec_batch: Optional[paddle.Tensor] = None
#
total_enc_len: Optional[paddle.Tensor] = None
# position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
pos_emb_type: Optional[str] = "NORMAL"
@dataclass

View File

@@ -183,5 +183,7 @@ class XPUAttentionBackend(AttentionBackend):
forward_meta.encoder_batch_map_cpu,
forward_meta.decoder_context_len_cpu,
forward_meta.decoder_batch_map_cpu,
forward_meta.pos_emb_type,
self.rope_3d,
)
return res

View File

@@ -72,7 +72,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
layer.top_k,
False, # moe group, used in deepseek
)
if layer.tp_size > 1:
if layer.reduce_results and layer.tp_size > 1:
from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce,
)
@@ -252,7 +252,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
layer.top_k,
False, # moe group, used in deepseek
)
if layer.tp_size > 1:
if layer.reduce_results and layer.tp_size > 1:
from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce,
)

View File

@@ -31,6 +31,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
cuda_graph_buffers,
support_graph_optimization,
@@ -44,20 +45,15 @@ from fastdeploy.model_executor.models.ernie4_5_moe import (
Ernie4_5_Attention,
Ernie4_5_MLP,
)
from fastdeploy.model_executor.models.ernie4_5_vl.image_op import (
text_image_gather_scatter,
text_image_index_out,
)
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
text_image_gather_scatter,
text_image_index_out,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
class Ernie4_5_VLMLP(Ernie4_5_MLP):

View File

@@ -0,0 +1,32 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
text_image_gather_scatter,
text_image_index_out,
)
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
text_image_gather_scatter,
text_image_index_out,
)
else:
raise ImportError("Unsupported platform; only CUDA and XPU are supported")
__all__ = ["text_image_gather_scatter", "text_image_index_out"]

View File

@@ -25,6 +25,7 @@ from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
@@ -34,10 +35,11 @@ from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
get_infer_param,
@@ -201,6 +203,45 @@ def xpu_post_process(
update_inputs,
)
# handle vl:
if model_output.enable_thinking:
exists_think_end = sampled_token_ids == model_output.think_end_id
paddle.assign(
paddle.where(
exists_think_end,
model_output.need_think_end - 1,
model_output.need_think_end,
),
model_output.need_think_end,
)
paddle.assign(
paddle.where(
model_output.need_think_end.cast("bool"),
model_output.reasoning_index - 1,
model_output.reasoning_index,
),
model_output.reasoning_index,
)
stop_wo_think = (
(sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
| (model_output.reasoning_index == 0)
) & (model_output.need_think_end > 0)
sampled_token_ids = paddle.where(
stop_wo_think,
model_output.think_end_id,
sampled_token_ids,
)
paddle.assign(
paddle.where(
stop_wo_think,
model_output.need_think_end - 1,
model_output.need_think_end,
),
model_output.need_think_end,
)
# 1. Set stop value
paddle.assign(
paddle.where(
@@ -340,11 +381,36 @@ class XPUModelRunner(ModelRunnerBase):
def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int):
super().__init__(fd_config=fd_config, device=device)
self.enable_mm = self.model_config.enable_mm
self.rank = rank
self.local_rank = local_rank
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
# VL model config:
if self.enable_mm:
self._init_image_preprocess()
self.amp_black = [
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div",
"sin",
"cos",
"sort",
"multinomial",
]
self.amp_white = [
"lookup_table",
"lookup_table_v2",
"flash_attn",
"matmul",
"matmul_v2",
"fused_gemm_epilogue",
]
# Sampler
self.sampler = Sampler()
# TODO(lilujia): sync with GPU
self.sampler = Sampler(fd_config)
# Lazy initialize kv cache after model loading
# self.kv_caches: list[paddle.Tensor] = []
@@ -364,18 +430,28 @@ class XPUModelRunner(ModelRunnerBase):
).cpu()
# Initialize attention Backend
# Note(gonshaotian): Currently, all attention layers share one attention backend instance.
# NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
# In the future, we will expand it as a list.
self.attn_backends: list[AttentionBackend] = []
self.initialize_attn_backend()
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None
def exist_prefill(self):
"""
Check whether a prefill stage exists
"""
if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
def insert_tasks_v1(self, req_dicts: List[Request]):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
req_dict: A list of Request dict
num_running_requests: batch_size
"""
# NOTE(luotingdan): Lazy initialize kv cache
if "caches" not in self.share_inputs:
@@ -388,10 +464,53 @@ class XPUModelRunner(ModelRunnerBase):
request = req_dicts[i]
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
logger.debug(f"Handle prefill request {request} at idx {idx}")
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
if self.enable_mm:
inputs = request.multimodal_inputs
if request.with_image:
vision_inputs = {}
vision_inputs["input_ids"] = paddle.to_tensor(
inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["token_type_ids"] = paddle.to_tensor(
inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
)
vision_inputs["image_type_ids"] = paddle.to_tensor(
inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
dtype=paddle.int64,
)
vision_inputs["images"] = paddle.to_tensor(
inputs["images"][request.image_start : request.image_end], dtype="uint8"
)
vision_inputs["grid_thw"] = paddle.to_tensor(
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
)
self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
else:
self.share_inputs["image_features"] = None
if inputs["position_ids"] is not None:
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64",
).unsqueeze([0])
else:
position_ids = None
enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)
if len(request.output_token_ids) == 0:
input_ids = request.prompt_token_ids
else:
input_ids = request.prompt_token_ids + request.output_token_ids
logger.debug(
f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
@@ -475,41 +594,86 @@ class XPUModelRunner(ModelRunnerBase):
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
request.stop_seqs_len.append(0)
self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
request.get("stop_token_ids"), dtype="int64"
request.sampling_params.stop_seqs_len.append(0)
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
request.sampling_params.stop_seqs_len, dtype="int32"
)
self.share_inputs["stop_seqs"][
idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
] = np.array(request.get("stop_token_ids"), dtype="int64")
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
def process_prefill_inputs(self, req_dicts: List[Request]):
def insert_prefill_inputs(self, req_dicts: List[Request]):
"""Process inputs for prefill tasks and update share_inputs buffer"""
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
length = request.prompt_token_ids_len
length = len(request.prompt_token_ids)
assert length > 0, "The prompt requested must not be empty."
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
if self.enable_mm:
inputs = self._preprocess_mm_task(request.multimodal_inputs)
if inputs.get("images") is not None:
self.share_inputs["image_features"] = self.extract_vision_features(inputs)
else:
# Handle requests that carry no images or videos
self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"]
length = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"]
else:
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length
if self.enable_mm:
enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value)
if res is not None:
return res
else:
return default_value
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["step_idx"][idx : idx + 1] = 0
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
request, "repetition_penalty", 1.0
)
self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request(
request, "frequency_penalty", 0.0
)
self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
request, "presence_penalty", 0.0
)
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
"max_tokens", self.model_config.max_model_len
)
@@ -540,11 +704,15 @@ class XPUModelRunner(ModelRunnerBase):
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
request.stop_seqs_len.append(0)
self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
request.get("stop_token_ids"), dtype="int64"
request.sampling_params.stop_seqs_len.append(0)
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
request.sampling_params.stop_seqs_len, dtype="int32"
)
self.share_inputs["stop_seqs"][
idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
] = np.array(request.get("stop_token_ids"), dtype="int64")
else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
self.share_inputs["not_need_stop"][0] = True
@@ -565,6 +733,11 @@ class XPUModelRunner(ModelRunnerBase):
self.model_config.pad_token_id,
dtype="int64",
)
self.share_inputs["prompt_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len],
self.model_config.pad_token_id,
dtype="int64",
)
self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
@@ -627,7 +800,9 @@ class XPUModelRunner(ModelRunnerBase):
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
# TODO(gongshaotian): move to models
if not self.enable_mm:
self.share_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
@@ -654,18 +829,40 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")
# Initialize stop seqs
self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
self.share_inputs["stop_seqs_len"] = paddle.full(
[max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
)
self.share_inputs["stop_seqs"] = paddle.full(
[
max_num_seqs,
self.model_config.max_stop_seqs_num,
self.model_config.stop_seqs_max_len,
],
-1,
dtype="int32",
dtype="int64",
)
if self.enable_mm:
head_dim = self.model_config.head_dim
self.share_inputs["rope_emb"] = paddle.full(
shape=[
max_num_seqs,
2,
1,
self.parallel_config.max_model_len,
1,
head_dim // 2,
],
fill_value=0,
dtype="float32",
)
self.share_inputs["image_features"] = None
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
def _prepare_inputs(self, is_dummy_run=False) -> None:
"""prepare the model inputs"""
"""Prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run:
recover_decode_task(
self.share_inputs["stop_flags"],
@@ -689,10 +886,13 @@ class XPUModelRunner(ModelRunnerBase):
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
if self.enable_mm: # pos_emb_type is different in EB and VL
self.forward_meta.pos_emb_type = "HALF_HEAD_DIM"
self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend()
# Get sampling metadata
# TODO(lilujia): sync with GPU
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
@@ -703,12 +903,16 @@ class XPUModelRunner(ModelRunnerBase):
seed=self.share_inputs["infer_seed"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"],
prompt_lens=self.share_inputs["prompt_lens"],
frequency_penalties=self.share_inputs["frequency_score"],
presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"],
min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"],
enable_early_stop=self.enable_early_stop,
stop_flags=self.share_inputs["stop_flags"],
)
def load_model(self) -> None:
@@ -723,7 +927,7 @@ class XPUModelRunner(ModelRunnerBase):
# 3. Load drafter model(for speculative decoding)
def get_model(self) -> nn.Layer:
"""get current model"""
"""Get current model"""
return self.model
def initialize_attention_backend(self):
@@ -741,6 +945,7 @@ class XPUModelRunner(ModelRunnerBase):
cache_kvs = {}
max_block_num = self.num_gpu_blocks
# Get kv cache dtype
cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
@@ -800,33 +1005,6 @@ class XPUModelRunner(ModelRunnerBase):
)
self.attn_backends.append(attn_backend)
def capture_model(self) -> None:
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
"""
logger.warn("XPU does not support CUDA graph currently")
pass
@sot_warmup_guard(True)
def sot_warmup(self) -> None:
start_time = time.perf_counter()
for batch_size in self.sot_warmup_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=batch_size,
)
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def exist_prefill(self):
"""
check whether prefill stage exist
"""
if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
"""Set dummy prefill inputs to share_inputs"""
full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10)
@@ -838,7 +1016,7 @@ class XPUModelRunner(ModelRunnerBase):
for i in range(batch_size):
idx = i
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
@@ -897,6 +1075,24 @@ class XPUModelRunner(ModelRunnerBase):
else:
paddle.device.xpu.set_debug_level(debug_level)
def capture_model(self) -> None:
"""
Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
"""
logger.warn("XPU does not support CUDA graph currently")
pass
@sot_warmup_guard(True)
def sot_warmup(self) -> None:
start_time = time.perf_counter()
for batch_size in self.sot_warmup_sizes:
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
batch_size=batch_size,
)
logger.info(f"SOT warmup the model with the batch size:{batch_size}")
logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
@@ -921,13 +1117,20 @@ class XPUModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda grph
# 3. Execute model
model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta)
if self.enable_mm:
model_output = self.model(
self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
)
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
hiddden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states)
logits = self.model.compute_logits(hidden_states)
sampler_output = self.sampler(logits, self.sampling_metadata)
# 5. Speculative decode
@@ -947,15 +1150,21 @@ class XPUModelRunner(ModelRunnerBase):
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
is_block_step=self.share_inputs["is_block_step"],
# speculative decoding
full_hidden_states=None,
msg_queue_id=self.parallel_config.msg_queue_id,
mp_rank=self.local_rank,
use_ep=self.parallel_config.use_ep,
# speculative decoding
full_hidden_states=None,
draft_tokens=None,
actual_draft_token_num=None,
accept_tokens=None,
accept_num=None,
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
xpu_post_process(
sampled_token_ids=sampler_output.sampled_token_ids,
@@ -984,13 +1193,43 @@ class XPUModelRunner(ModelRunnerBase):
@profile_run_guard(True)
def profile_run(self) -> None:
"""Execute a forward pass with dummy inputs to profile the memory usage of the model."""
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
self._dummy_run(
num_tokens=int(self.scheduler_config.max_num_batched_tokens),
batch_size=min(self.scheduler_config.max_num_seqs, 1),
)
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
"""
Set a globally unified block number and update the model's shared input.
Args:
num_gpu_blocks:
"""
self.num_gpu_blocks = num_gpu_blocks
# Reset block table and kv cache with global block num
self.initialize_kv_cache()
# Reset free list
free_list = list(
range(
self.num_gpu_blocks - 1,
int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
-1,
)
)
self.free_list_len = len(free_list)
self.share_inputs.update(
{
"free_list": paddle.to_tensor(free_list, dtype="int32"),
"free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
}
)
def clear_block_table(self) -> None:
"""
Clear the block tables and kv cache after profiling.
@@ -1025,41 +1264,135 @@ class XPUModelRunner(ModelRunnerBase):
byte_of_dtype = 2
hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
required_memory = (
byte_of_dtype
* 2 # k + v
* (self.cache_config.block_size * hidden_dim)
* self.model_config.num_hidden_layers
)
num_layers = self.model_config.num_hidden_layers
required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v
return required_memory
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
"""
Set a globally unified block number and update the model's shared input.
Args:
num_gpu_blocks:
"""
self.num_gpu_blocks = num_gpu_blocks
# Reset block table and kv cache with global block num
self.initialize_kv_cache()
# Reset free list
free_list = list(
range(
self.num_gpu_blocks - 1,
int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
-1,
)
)
self.free_list_len = len(free_list)
self.share_inputs.update(
{
"free_list": paddle.to_tensor(free_list, dtype="int32"),
"free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
}
)
def not_need_stop(self) -> bool:
""" """
"""Stop decoding if the tensor meets the termination condition"""
return self.share_inputs["not_need_stop"][0]
def clear_cache(self):
"""Clear cached data from shared inputs and forward metadata"""
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
self.forward_meta.clear_caches()
def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.model_config.model,
image_preprocessor_name=str(self.model_config.model),
)
processor.eval()
image_preprocess = processor.image_preprocessor
image_preprocess.image_mean_tensor = paddle.to_tensor(image_preprocess.image_mean, dtype="float32").reshape(
[1, 3, 1, 1]
)
image_preprocess.image_std_tensor = paddle.to_tensor(image_preprocess.image_std, dtype="float32").reshape(
[1, 3, 1, 1]
)
image_preprocess.rescale_factor = paddle.to_tensor(image_preprocess.rescale_factor, dtype="float32")
image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze([-2, -1]).repeat_interleave(
self.model_config.vision_config.patch_size**2 * 1, -1
)
image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze([-2, -1]).repeat_interleave(
self.model_config.vision_config.patch_size**2 * 1, -1
)
self.image_preprocess = image_preprocess
def _preprocess_mm_task(self, one: dict) -> None:
"""process batch"""
input_ids = one["input_ids"][np.newaxis, :]
input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
token_type_ids = one["token_type_ids"][np.newaxis, :]
token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if one["images"] is not None:
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
images = paddle.to_tensor(images, dtype="uint8")
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
else:
image_type_ids = None
images = None
grid_thw = None
if one["position_ids"] is not None:
position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
else:
position_ids = None
result = dict(
input_ids=input_ids,
image_type_ids=image_type_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
grid_thw=grid_thw,
images=images,
)
return result
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"].cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
images = images.cast("bfloat16")
token_type_ids = inputs["token_type_ids"]
token_type_ids_w_video = token_type_ids
input_ids = inputs["input_ids"]
# convert to img patch id
# TODO(lulinjun): may need to check model_config and model_cfg
image_mask = input_ids == self.model_config.im_patch_id
image_type_ids = inputs["image_type_ids"]
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.parallel_config.dtype,
):
image_features = self.model.vision_model.extract_feature(images, grid_thw)
if self.parallel_config.tensor_parallel_size > 1:
S, C = image_features.shape
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
image_features = ScatterOp.apply(image_features, axis=-1)  # split features across tensor-parallel ranks
image_features = image_features.reshape([S, -1])
image_features = self.model.resampler_model(
image_features,
image_mask,
token_type_ids_w_video,
image_type_ids,
grid_thw,
)
return image_features
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
[1, 1, 3],
)
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
rope_emb = get_rope_3d(
position_ids=position_ids_3d_real,
rotary_dim=self.model_config.head_dim,
partial_rotary_factor=1.0,
base=self.model_config.rope_theta,
max_position=self.parallel_config.max_model_len,
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
model_type=self.model_config.model_type,
)
return rope_emb
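For orientation, a minimal sketch (toy sizes, pure paddle) of the position handling in prepare_rope3d above: the prompt supplies per-token 3D positions, and decode positions are appended as a monotonically increasing index shared across the three axes. The axis meaning (e.g. time/height/width) is an assumption here, not taken from this diff.

```python
import paddle

# [batch=1, prefix_len=3, 3] prompt positions, e.g. produced by the VL processor
position_ids = paddle.to_tensor([[[0, 0, 0], [1, 0, 1], [2, 1, 0]]], dtype="int64")
max_dec_len = 2

prefix_max = paddle.max(position_ids) + 1                      # 3
dec_pos_ids = paddle.tile(
    paddle.arange(max_dec_len, dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3]
) + prefix_max                                                 # [[[3, 3, 3], [4, 4, 4]]]
position_ids_3d = paddle.concat([position_ids, dec_pos_ids], axis=1)
print(position_ids_3d.shape)                                   # [1, 5, 3]
```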

View File

@@ -51,12 +51,13 @@ class XpuWorker(WorkerBase):
"""Initialize device and Construct model runner"""
if paddle.is_compiled_with_xpu():
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"xpu:{self.local_rank}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype)
self.device_ids = self.parallel_config.device_ids.split(",")
gc.collect()
paddle.device.xpu.empty_cache()
else:
raise RuntimeError(f"Unsupported device type: {self.device_config.device}")
@@ -69,12 +70,11 @@ class XpuWorker(WorkerBase):
local_rank=self.local_rank,
)
def graph_optimize_and_warm_up_model(self) -> None:
def exist_prefill(self):
"""
Perform the warm-up and the graph optimization
Check whether a prefill stage exists
"""
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
return self.model_runner.exist_prefill()
def determine_available_memory(self) -> int:
"""
@@ -133,20 +133,17 @@ class XpuWorker(WorkerBase):
paddle.device.xpu.empty_cache()
return available_kv_cache_memory # approximate value
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()
def load_model(self) -> None:
""" """
"""Load model"""
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
""" """
"""Get current model"""
return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int) -> None:
""" """
"""Initizlize the KV Cache with accurate num_gpu_blocks"""
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def execute_model(
@@ -158,12 +155,6 @@ class XpuWorker(WorkerBase):
""" """
return self.model_runner.execute_model(model_forward_batch, num_running_requests, is_dummy_run)
def exist_prefill(self):
"""
check whether prefill stage exist
"""
return self.model_runner.exist_prefill()
def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int = -1) -> None:
"""Process new requests and then start the decode loop
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
@@ -172,8 +163,19 @@ class XpuWorker(WorkerBase):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
else:
self.model_runner.process_prefill_inputs(req_dicts=req_dicts)
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
def check_health(self) -> bool:
""" """
return True
def cal_theortical_kvcache(self) -> int:
"""Calculate the block memory required"""
return self.model_runner.cal_theortical_kvcache()