[XPU] Support kvblock centralized management (#3017)

yinwei
2025-07-29 10:40:55 +08:00
committed by GitHub
parent 286802a070
commit f2a528f9ae
10 changed files with 843 additions and 21 deletions


@@ -0,0 +1,68 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
int r = baidu::xpu::api::plugin::recover_decode_task(
xpu_ctx->x_context(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
bsz,
block_num_per_seq,
block_size);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::recover_decode_task failed.");
}
PD_BUILD_OP(recover_decode_task)
.Inputs({"stop_flags",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_seq_lens_decoder",
"block_tables",
"is_block_step"})
.Attrs({"block_size: int"})
.Outputs({"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"stop_flags_out",
"is_block_step_out"})
.SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"stop_flags", "stop_flags_out"},
{"is_block_step", "is_block_step_out"}})
.SetKernelFn(PD_KERNEL(RecoverDecodeTask));
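
A quick sanity check of how the new op is driven from Python. This is a hedged sketch, not part of the diff: it assumes the XPU custom ops have been built and an XPU device is visible, and the tensor shapes only roughly follow the buffers the model runner allocates below. Because of the SetInplaceMap above, the op mutates its inputs instead of returning new tensors.

import paddle
from fastdeploy.model_executor.ops.xpu import recover_decode_task

max_num_seqs, block_num_per_seq, block_size = 4, 8, 64
stop_flags = paddle.full([max_num_seqs], True, dtype="bool")
seq_lens_this_time = paddle.full([max_num_seqs, 1], 0, dtype="int32")
seq_lens_encoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
seq_lens_decoder = paddle.full([max_num_seqs, 1], 0, dtype="int32")
step_seq_lens_decoder = paddle.full([max_num_seqs, 1], 128, dtype="int32")
block_tables = paddle.full([max_num_seqs, block_num_per_seq], 3, dtype="int32")
is_block_step = paddle.full([max_num_seqs], True, dtype="bool")

# In-place update: any preempted slot (is_block_step) whose kv block at
# step_seq_lens_decoder // block_size is assigned again is switched back to a
# decode task (seq_lens_this_time = 1, stop_flags = False).
recover_decode_task(
    stop_flags,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    step_seq_lens_decoder,
    block_tables,
    is_block_step,
    block_size,
)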


@@ -0,0 +1,105 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
void UpdateInputesV1(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop, // only on cpu
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &topk_ids,
const paddle::Tensor &input_ids,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_nums,
const paddle::Tensor &next_tokens,
const paddle::Tensor &is_block_step,
const int block_size) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx =
paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
// std::cout << "now_bsz: " << now_bsz << std::endl;
const int input_ids_stride = input_ids.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
int r = baidu::xpu::api::plugin::update_inputs_v1(
xpu_ctx->x_context(),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(prompt_lens.data<int64_t>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(input_ids.data<int64_t>()),
const_cast<int *>(block_tables.data<int>()),
stop_nums.data<int64_t>(),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<bool *>(is_block_step.data<bool>()),
next_tokens.data<int64_t>(),
now_bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
PD_CHECK(r == 0, "baidu::xpu::api::plugin::update_inputs_v1 failed.");
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_OP(update_inputs_v1)
.Inputs({"stop_flags",
"not_need_stop",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_seq_lens_decoder",
"prompt_lens",
"topk_ids",
"input_ids",
"block_tables",
"stop_nums",
"next_tokens",
"is_block_step"})
.Attrs({"block_size: int"})
.Outputs({"not_need_stop_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_seq_lens_decoder_out",
"topk_ids_out",
"input_ids_out",
"stop_flags_out",
"is_block_step_out"})
.SetInplaceMap({{"not_need_stop", "not_need_stop_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"topk_ids", "topk_ids_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"is_block_step", "is_block_step_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesV1));
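
Same idea for update_inputs_v1: a hedged usage sketch (not part of the diff), again assuming a built XPU extension. Note that not_need_stop is expected to live on CPU; the op copies it to the device, lets the kernel run its stop-flag reduction, and copies the single flag back. All other listed inputs are updated in place. Shapes and fill values here are illustrative only.

import paddle
from fastdeploy.model_executor.ops.xpu import update_inputs_v1

max_bsz, now_bsz, max_len, block_num_per_seq, block_size = 8, 4, 1024, 16, 64
stop_flags = paddle.full([max_bsz], False, dtype="bool")
not_need_stop = paddle.full([1], True, dtype="bool").cpu()  # kept on CPU
seq_lens_this_time = paddle.full([now_bsz], 1, dtype="int32")
seq_lens_encoder = paddle.full([max_bsz, 1], 0, dtype="int32")
seq_lens_decoder = paddle.full([max_bsz, 1], 32, dtype="int32")
step_seq_lens_decoder = paddle.full([max_bsz, 1], 0, dtype="int32")
prompt_lens = paddle.full([max_bsz, 1], 16, dtype="int64")
topk_ids = paddle.full([now_bsz, 1], 0, dtype="int64")
input_ids = paddle.full([max_bsz, max_len], 0, dtype="int64")
block_tables = paddle.full([max_bsz, block_num_per_seq], 0, dtype="int32")
stop_nums = paddle.full([1], max_bsz, dtype="int64")
next_tokens = paddle.full([max_bsz, 1], 2, dtype="int64")
is_block_step = paddle.full([max_bsz], False, dtype="bool")

# In-place update of the decoding state for the first now_bsz slots; the
# kernel's reduction over stop flags decides the value written back into
# not_need_stop.
update_inputs_v1(
    stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder,
    seq_lens_decoder, step_seq_lens_decoder, prompt_lens, topk_ids,
    input_ids, block_tables, stop_nums, next_tokens, is_block_step, block_size,
)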


@@ -86,6 +86,39 @@ recover_block(Context *ctx,
              const int block_num_per_seq, const int length,
              const int pre_id_length);
DLL_EXPORT int
recover_decode_task(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size);
DLL_EXPORT int
update_inputs_v1(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
template <typename TX, typename TY>
DLL_EXPORT int
eb_adjust_batch(Context *ctx, const TX *x, TY *y,


@@ -0,0 +1,41 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu3 {
namespace plugin {
__global__ void recover_decode_task(bool *stop_flags,
                                    int *seq_lens_this_time,
                                    int *seq_lens_encoder,
                                    int *seq_lens_decoder,
                                    int *step_seq_lens_decoder,
                                    int *block_tables,
                                    bool *is_block_step,
                                    const int bsz,
                                    const int block_num_per_seq,
                                    const int block_size) {
    int cid = core_id();
    int ncores = core_num();
    int clusterid = cluster_id();
    int nclusters = cluster_num();
    int thread_idx = clusterid * ncores + cid;
    int nthreads = nclusters * ncores;
    // if (clusterid != 0) return;
    // Each hardware thread handles a strided subset of the batch slots.
    for (; thread_idx < bsz; thread_idx += nthreads) {
        if (is_block_step[thread_idx]) {
            // int *block_table_now = block_tables + thread_idx * block_num_per_seq;
            if (block_tables[thread_idx * block_num_per_seq +
                             step_seq_lens_decoder[thread_idx] / block_size] != -1) {
                // can be recovered for decoding
                is_block_step[thread_idx] = false;
                seq_lens_this_time[thread_idx] = 1;
                stop_flags[thread_idx] = false;
                seq_lens_encoder[thread_idx] = 0;
                seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
            }
        }
    }
}
} // namespace plugin
} // namespace xpu3
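
For readers who do not want to parse the XPU intrinsics, a hedged pure-Python restatement of the per-slot logic above. The arguments are assumed to be NumPy arrays shaped like the runner's share_inputs buffers; this is a readability aid, not shipped code.

def recover_decode_task_ref(stop_flags, seq_lens_this_time, seq_lens_encoder,
                            seq_lens_decoder, step_seq_lens_decoder,
                            block_tables, is_block_step, block_size):
    bsz = seq_lens_this_time.shape[0]
    for i in range(bsz):
        if not is_block_step[i]:
            continue
        # The slot was preempted; resume it once the kv block covering its
        # saved decode position has been assigned again.
        if block_tables[i, int(step_seq_lens_decoder[i]) // block_size] != -1:
            # can be recovered for decoding
            is_block_step[i] = False
            seq_lens_this_time[i] = 1
            stop_flags[i] = False
            seq_lens_encoder[i] = 0
            seq_lens_decoder[i] = step_seq_lens_decoder[i]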


@@ -0,0 +1,131 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
// #include <stdio.h>
// using namespace std;
#include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk.h"
namespace xpu3 {
namespace plugin {
__global__ void update_inputs_v1(bool *not_need_stop,
                                 int *seq_lens_this_time,
                                 int *seq_lens_encoder,
                                 int *seq_lens_decoder,
                                 int *step_seq_lens_decoder,
                                 int64_t *prompt_lens,
                                 int64_t *topk_ids,
                                 int64_t *input_ids,
                                 int *block_tables,
                                 const int64_t *stop_nums,
                                 bool *stop_flags,
                                 bool *is_block_step,
                                 const int64_t *next_tokens,
                                 const int bsz,
                                 const int max_bsz,
                                 const int input_ids_stride,
                                 const int block_num_per_seq,
                                 const int block_size) {
    // std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl;
    int cid = core_id();
    int ncores = core_num();
    int clusterid = cluster_id();
    int nclusters = cluster_num();
    int thread_idx = clusterid * ncores + cid;
    // The stop-flag reduction below happens in cluster shared memory, so only
    // the first cluster does any work.
    if (clusterid != 0) return;
    const int max_bs = 1024;
    __shared__ bool stop_flags_sm[max_bs];
    __shared__ int stop_flags_int_sm[max_bs];
    if (cid == 0) {
        GM2SM(stop_flags, stop_flags_sm, sizeof(bool) * bsz);
    }
    sync_all();
    for (int i = cid; i < max_bsz; i += ncores) {
        if (i < bsz) {
            stop_flags_sm[i] = stop_flags[i];
            stop_flags_int_sm[i] = static_cast<int>(stop_flags_sm[i]);
        } else {
            // Pad unused batch slots as "stopped" so the reduction over
            // max_bsz below only reads initialized values.
            stop_flags_sm[i] = true;
            stop_flags_int_sm[i] = 1;
        }
        if (i < bsz) {
            int seq_len_this_time_update = 0;
            int seq_len_decoder_update = 0;
            int seq_lens_encoder_update = 0;
            if (stop_flags_sm[i]) {
                // Already stopped: clear the per-slot lengths.
                LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
                LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
                LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
            } else {
                GM2LM(seq_lens_this_time + i, &seq_len_this_time_update, sizeof(int));
                GM2LM(seq_lens_decoder + i, &seq_len_decoder_update, sizeof(int));
                GM2LM(seq_lens_encoder + i, &seq_lens_encoder_update, sizeof(int));
                int sum_of_seq_lens_this_time_and_seq_lens_decoder =
                    seq_len_this_time_update + seq_len_decoder_update;
                int64_t prompt_lens_update = 0;
                GM2LM(prompt_lens + i, &prompt_lens_update, sizeof(int64_t));
                // decoding
                if (sum_of_seq_lens_this_time_and_seq_lens_decoder >= prompt_lens_update) {
                    // The prompt has been fully consumed: run one decode step and
                    // feed the sampled token back as the next input.
                    seq_len_decoder_update = seq_len_this_time_update + seq_len_decoder_update;
                    LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
                    seq_len_this_time_update = 1;
                    LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
                    seq_lens_encoder_update = 0;
                    LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
                    int64_t input_ids_update;
                    GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
                    LM2GM(&input_ids_update, input_ids + i * input_ids_stride, sizeof(int64_t));
                    // to judge whether block is not enough
                    if (seq_len_this_time_update != 0 &&
                        block_tables[i * block_num_per_seq +
                                     seq_len_decoder_update / block_size] == -1) {
                        // No free kv block: preempt the slot and remember its decode
                        // length so recover_decode_task can resume it later.
                        is_block_step[i] = true;
                        seq_len_this_time_update = 0;
                        LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
                        stop_flags_sm[i] = true;
                        SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
                        LM2GM(&seq_len_decoder_update, step_seq_lens_decoder + i, sizeof(int));
                        seq_len_decoder_update = 0;
                        LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
                        stop_flags_int_sm[i] = 1;
                    }
                } else {
                    // The prompt is not fully processed yet (chunked prefill):
                    // keep the slot stopped and emit no token this step.
                    stop_flags_sm[i] = true;
                    SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
                    seq_len_this_time_update = 0;
                    LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
                    seq_len_decoder_update = 0;
                    seq_lens_encoder_update = 0;
                    LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
                    LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
                    int64_t topk_ids_update = -1;
                    LM2GM(&topk_ids_update, topk_ids + i, sizeof(int64_t));
                    stop_flags_int_sm[i] = 1;
                }
            }
        }
    }
    sync_all();
    sync_cluster();
    int stop_sum = 0;
    if (cid == 0) {
        // Reduce the stop flags over all max_bsz slots and decide whether the
        // batch as a whole still needs to run.
        for (int i = 0; i < max_bsz; i++) {
            stop_sum += stop_flags_int_sm[i];
        }
        // printf("stop_sum : %d\n", stop_sum);
        int64_t stop_num;
        GM2LM(stop_nums, &stop_num, sizeof(int64_t));
        bool not_need_stop_update = stop_sum < static_cast<int>(stop_num);
        mfence_lm();
        LM2GM(&not_need_stop_update, not_need_stop, sizeof(bool));
    }
}
} // namespace plugin
} // namespace xpu3
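
The same kind of hedged NumPy restatement for update_inputs_v1. The function below mirrors the per-slot branches of the kernel (stopped slot, decode step, out-of-block preemption, unfinished chunked prefill) and the final stop-count reduction; it assumes 1-D arrays per slot plus 2-D input_ids and block_tables, and is a readability aid rather than the shipped kernel. The return value corresponds to the not_need_stop flag.

import numpy as np

def update_inputs_v1_ref(stop_flags, seq_lens_this_time, seq_lens_encoder,
                         seq_lens_decoder, step_seq_lens_decoder, prompt_lens,
                         topk_ids, input_ids, block_tables, stop_nums,
                         is_block_step, next_tokens, block_size):
    bsz = seq_lens_this_time.shape[0]
    max_bsz = stop_flags.shape[0]
    for i in range(bsz):
        if stop_flags[i]:
            # Already stopped: clear the per-slot lengths.
            seq_lens_this_time[i] = 0
            seq_lens_decoder[i] = 0
            seq_lens_encoder[i] = 0
            continue
        if seq_lens_this_time[i] + seq_lens_decoder[i] >= prompt_lens[i]:
            # Prompt fully consumed: take one decode step and feed the sampled
            # token back as the next input.
            seq_lens_decoder[i] += seq_lens_this_time[i]
            seq_lens_this_time[i] = 1
            seq_lens_encoder[i] = 0
            input_ids[i, 0] = next_tokens[i]
            if block_tables[i, int(seq_lens_decoder[i]) // block_size] == -1:
                # Out of kv blocks: preempt the slot and remember where it stopped.
                is_block_step[i] = True
                step_seq_lens_decoder[i] = seq_lens_decoder[i]
                seq_lens_this_time[i] = 0
                seq_lens_decoder[i] = 0
                stop_flags[i] = True
        else:
            # Chunked prefill not finished: no token is emitted this step.
            stop_flags[i] = True
            seq_lens_this_time[i] = 0
            seq_lens_decoder[i] = 0
            seq_lens_encoder[i] = 0
            topk_ids[i] = -1
    # Slots beyond the live batch count as stopped; generation continues while
    # the number of stopped slots is below the stop threshold.
    stopped = int(np.sum(stop_flags[:bsz])) + (max_bsz - bsz)
    return stopped < int(stop_nums[0])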


@@ -0,0 +1,107 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
recover_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int xpu3_wrapper(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto recover_decode_task = xpu3::plugin::recover_decode_task;
recover_decode_task<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
return api::SUCCESS;
}
int recover_decode_task(Context *ctx, bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
const int bsz,
const int block_num_per_seq,
const int block_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "recover_decode_task", int);
WRAPPER_DUMP_PARAM5(ctx, stop_flags, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder);
WRAPPER_DUMP_PARAM2(ctx, block_tables, is_block_step);
WRAPPER_DUMP_PARAM3(ctx, bsz, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
block_tables,
is_block_step,
bsz,
block_num_per_seq,
block_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu


@@ -0,0 +1,149 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include <algorithm>
#include <numeric>
namespace xpu3 {
namespace plugin {
__attribute__((global)) void
update_inputs_v1(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
} // namespace plugin
} // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
static int xpu3_wrapper(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
// The kernel has to do a reduction internally, so only one cluster can be used.
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(prompt_lens),
reinterpret_cast<XPU_INT64 *>(topk_ids),
reinterpret_cast<XPU_INT64 *>(input_ids),
block_tables,
reinterpret_cast<const XPU_INT64 *>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
return api::SUCCESS;
}
int update_inputs_v1(Context *ctx, bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "update_inputs_v1", int);
WRAPPER_DUMP_PARAM5(ctx, not_need_stop, seq_lens_this_time,
seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder);
WRAPPER_DUMP_PARAM5(ctx, prompt_lens, topk_ids, input_ids, block_tables, stop_nums);
WRAPPER_DUMP_PARAM3(ctx, stop_flags, is_block_step, next_tokens);
WRAPPER_DUMP_PARAM5(ctx, bsz, max_bsz, input_ids_stride, block_num_per_seq, block_size);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx, not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
prompt_lens,
topk_ids,
input_ids,
block_tables,
stop_nums,
stop_flags,
is_block_step,
next_tokens,
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
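
The single-cluster launch above exists because the stop-count reduction runs in cluster shared memory. The host-visible effect of that reduction is small enough to restate in NumPy; this is a hedged reference for readers, not shipped code, and the padding of unused slots mirrors the kernel's shared-memory initialization.

import numpy as np

def not_need_stop_ref(stop_flags_after_update, stop_nums, max_bsz):
    # stop_flags_after_update holds the flags of the live (first bsz) slots
    # after the per-slot update; slots beyond the live batch count as stopped.
    bsz = stop_flags_after_update.shape[0]
    stopped = int(np.sum(stop_flags_after_update)) + (max_bsz - bsz)
    return stopped < int(stop_nums[0])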


@@ -144,6 +144,8 @@ def xpu_setup_ops():
"./ops/get_token_penalty_multi_scores.cc", "./ops/get_token_penalty_multi_scores.cc",
"./ops/get_padding_offset.cc", "./ops/get_padding_offset.cc",
"./ops/update_inputs.cc", "./ops/update_inputs.cc",
"./ops/recover_decode_task.cc",
"./ops/update_inputs_v1.cc",
"./ops/get_output.cc", "./ops/get_output.cc",
"./ops/step.cc", "./ops/step.cc",
"./ops/get_infer_param.cc", "./ops/get_infer_param.cc",


@@ -22,8 +22,9 @@ import numpy as np
import paddle
from paddle import nn

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (

@@ -33,6 +34,13 @@ from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.ops.xpu import (
    adjust_batch,
    get_infer_param,
    get_padding_offset,
    recover_decode_task,
    update_inputs_v1,
)
from fastdeploy.utils import get_logger
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -53,11 +61,6 @@ def xpu_pre_process(
    max_len = input_ids.shape[1]
    cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time)
    token_num = paddle.sum(seq_lens_this_time)

    (
        ids_remove_padding,

@@ -111,6 +114,18 @@
    ) = get_infer_param(seq_lens_encoder, seq_lens_decoder)

    # Adjust batch
# print(f"=========================adjust_batch 更新前=========================")
# print(f"ids_remove_padding : {ids_remove_padding}")
# print(f"cum_offsets : {cum_offsets}")
# print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
# print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
# print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
# print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
# print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
# print(f"xpu_forward_meta.dec_batch : {xpu_forward_meta.decoder_batch_map}")
    adjusted_input = adjust_batch(
        ids_remove_padding.reshape([-1, 1]),
        cum_offsets,

@@ -125,6 +140,17 @@
        None,  # output_padding_offset
        -1,  # max_input_length
    )
# print(f"=========================adjust_batch 更新后=========================")
# print(f"ids_remove_padding : {ids_remove_padding}")
# print(f"cum_offsets : {cum_offsets}")
# print(f"xpu_forward_meta.encoder_seq_lod : {xpu_forward_meta.encoder_seq_lod}")
# print(f"xpu_forward_meta.encoder_batch_idx: {xpu_forward_meta.encoder_batch_idx}")
# print(f"xpu_forward_meta.decoder_batch_idx : {xpu_forward_meta.decoder_batch_idx}")
# print(f"xpu_forward_meta.encoder_seq_lod_cpu : {xpu_forward_meta.encoder_seq_lod_cpu}")
# print(f"xpu_forward_meta.encoder_batch_idx_cpu : {xpu_forward_meta.encoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.decoder_batch_idx_cpu : {xpu_forward_meta.decoder_batch_idx_cpu}")
# print(f"xpu_forward_meta.enc_batch : {xpu_forward_meta.encoder_batch_map}")
    adjusted_input = adjusted_input.squeeze(1)
    share_inputs["ids_remove_padding"] = adjusted_input

@@ -160,7 +186,9 @@
def xpu_post_process(
    sampled_token_ids: paddle.Tensor,
    model_output: ModelOutputData,
    share_inputs: Dict[str, paddle.Tensor],
    block_size: int = 64,
    skip_save_output: bool = False,
) -> None:
    """ """
    from fastdeploy.model_executor.ops.xpu import (
@@ -194,17 +222,66 @@ def xpu_post_process(
    # 2. Update the input buffer of the model
    with paddle.framework._no_check_dy2st_diff():
        if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output:
            # print(f"==================== update_inputs_v1: before update ====================")
            # print(f"model_output.stop_flags : {model_output.stop_flags}")
            # print(f"model_output.not_need_stop : {model_output.not_need_stop}")
            # print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
            # print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
            # print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
            # print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
            # print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
            # print(f"sampled_token_ids : {sampled_token_ids}")
            # print(f"model_output.input_ids : {model_output.input_ids}")
            # print(f"model_output.stop_nums : {model_output.stop_nums}")
            # print(f"model_output.next_tokens : {model_output.next_tokens}")
            # print(f"model_output.is_block_step : {model_output.is_block_step}")
            # print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
            # print(f"block_size : {block_size}")
            update_inputs_v1(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                share_inputs["step_seq_lens_decoder"],
                share_inputs["prompt_lens"],
                sampled_token_ids,
                model_output.input_ids,
                share_inputs["block_tables"],
                model_output.stop_nums,
                model_output.next_tokens,
                model_output.is_block_step,
                block_size,
            )
            # print(f"==================== update_inputs_v1: after update ====================")
            # print(f"model_output.stop_flags : {model_output.stop_flags}")
            # print(f"model_output.not_need_stop : {model_output.not_need_stop}")
            # print(f"model_output.seq_lens_this_time : {model_output.seq_lens_this_time}")
            # print(f"model_output.seq_lens_encoder : {model_output.seq_lens_encoder}")
            # print(f"model_output.seq_lens_decoder : {model_output.seq_lens_decoder}")
            # print(f"share_inputs['step_seq_lens_decoder'] : {share_inputs['step_seq_lens_decoder']}")
            # print(f"share_inputs['prompt_lens'] : {share_inputs['prompt_lens']}")
            # print(f"sampled_token_ids : {sampled_token_ids}")
            # print(f"model_output.input_ids : {model_output.input_ids}")
            # print(f"model_output.stop_nums : {model_output.stop_nums}")
            # print(f"model_output.next_tokens : {model_output.next_tokens}")
            # print(f"model_output.is_block_step : {model_output.is_block_step}")
            # print(f"share_inputs['block_tables'] : {share_inputs['block_tables']}")
            # print(f"block_size : {block_size}")
        else:
            update_inputs(
                model_output.stop_flags,
                model_output.not_need_stop,
                model_output.seq_lens_this_time,
                model_output.seq_lens_encoder,
                model_output.seq_lens_decoder,
                model_output.input_ids,
                model_output.stop_nums,
                sampled_token_ids,
                model_output.is_block_step,
            )

    # 3. Transmit the model's output and stop generation signal via message queue.
    # In the future, we will abandon this approach.
    if not skip_save_output:
@@ -290,6 +367,96 @@ class XPUModelRunner(ModelRunnerBase):
        # Forward meta stores the global meta information of the forward
        self.forward_meta: ForwardMeta = None
    def insert_tasks_v1(self, req_dicts: List[Request]):
        """
        Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
        """
        # NOTE(luotingdan): Lazy initialize kv cache
        if "caches" not in self.share_inputs:
            self.initialize_kv_cache()

        req_len = len(req_dicts)
        has_prefill_task = False
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                logger.debug(f"Handle prefill request {request} at idx {idx}")
                prefill_start_index = request.prefill_start_index
                prefill_end_index = request.prefill_end_index
                length = prefill_end_index - prefill_start_index
                input_ids = request.prompt_token_ids + request.output_token_ids
                self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                    input_ids[prefill_start_index:prefill_end_index]
                )
                encoder_block_num = len(request.block_tables)
                self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
                self.share_inputs["stop_flags"][idx : idx + 1] = False
                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
                self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
                self.share_inputs["is_block_step"][idx : idx + 1] = False
                self.share_inputs["step_idx"][idx : idx + 1] = (
                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
                )
                has_prefill_task = True
            elif request.task_type.value == RequestType.DECODE.value:  # decode task
                logger.debug(f"Handle decode request {request} at idx {idx}")
                encoder_block_num = len(request.block_tables)
                self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
                continue
            else:  # preempted task
                logger.debug(f"Handle preempted request {request} at idx {idx}")
                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
                self.share_inputs["stop_flags"][idx : idx + 1] = True
                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0
                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.share_inputs["is_block_step"][idx : idx + 1] = False
                continue

            if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
                request.eos_token_ids.append(request.eos_token_ids[0])
            self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
            self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)

            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
            self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
                "max_tokens", self.model_config.max_model_len
            )

            self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
            self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length

            if request.get("seed") is not None:
                self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")

            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                stop_seqs_num = len(request.get("stop_seqs_len"))
                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
                    request.stop_seqs_len.append(0)
                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                    request.get("stop_token_ids"), dtype="int64"
                )

        if has_prefill_task:
            self.share_inputs["not_need_stop"][0] = True
    def process_prefill_inputs(self, req_dicts: List[Request]):
        """Process inputs for prefill tasks and update share_inputs buffer"""
        req_len = len(req_dicts)
@@ -392,6 +559,8 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["not_need_stop"] = paddle.full( self.share_inputs["not_need_stop"] = paddle.full(
[1], False, dtype="bool" [1], False, dtype="bool"
@@ -455,8 +624,19 @@ class XPUModelRunner(ModelRunnerBase):
dtype="int32", dtype="int32",
) )
def _prepare_inputs(self) -> None: def _prepare_inputs(self, is_dummy_run=False) -> None:
"""prepare the model inputs""" """prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run:
recover_decode_task(
self.share_inputs["stop_flags"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["block_tables"],
self.share_inputs["is_block_step"],
self.parallel_config.block_size,
)
self.forward_meta = xpu_pre_process( self.forward_meta = xpu_pre_process(
self.share_inputs["input_ids"], self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
@@ -655,7 +835,7 @@ class XPUModelRunner(ModelRunnerBase):
            intermediate_tensors:
        """
        # 1. Prepare inputs of model and decoder.
        self._prepare_inputs(is_dummy_run=is_dummy_run)

        # 2. Padding inputs for cuda graph
@@ -699,6 +879,8 @@ class XPUModelRunner(ModelRunnerBase):
        xpu_post_process(
            sampled_token_ids=sampler_output.sampled_token_ids,
            model_output=model_output_data,
            share_inputs=self.share_inputs,
            block_size=self.parallel_config.block_size,
            skip_save_output=is_dummy_run,
        )


@@ -20,6 +20,7 @@ from typing import List, Optional
import paddle
from paddle import nn

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
@@ -154,7 +155,10 @@ class XpuWorker(WorkerBase):
        TODO(gongshaotian):The scheduler should schedule the handling of prefill,
        and workers and modelrunners should not perceive it.
        """
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
        else:
            self.model_runner.process_prefill_inputs(req_dicts=req_dicts)

    def check_health(self) -> bool:
        """ """