mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Support pd ep deployment with yiyan adapter (#4029)
* [Feature] Support mixed deployment with yiyan adapter in release2.2 * fix metrics * add unit test * add unit test * add unit test * Support pd ep deployment with yiyan adapter * Support pd ep deployment with yiyan adapter * refactor cache messager * support scheduler v1 in PD * suppport pd v1 + chunk prefill * suppport pd v1 + chunk prefill * add eplb * support eplb * support eplb * support eplb * support v1 * fix * fix * fix bug * remove eplb support * support prefix cache in P * fix bug * fix bug * support one stop in V1 * fix bug * fix ci * fix ci * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
} else {
|
||||
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
if (prefill_one_step_stop) {
|
||||
// prefill done, stop
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
} else{
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
|
||||
Reference in New Issue
Block a user