mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature][MTP] Support cacheKV transfer in per_chunk mode (#2890)
* support chunk_prefill in both normal and speculative_decoding (MTP) modes * optimize pd-disaggregation config * fix bug
This commit is contained in:
@@ -36,9 +36,9 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int* out_data = const_cast<int*>(x.data<int>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, IPC_NOWAIT);
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
|
||||
} else {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 2 + 2) * 4, 0, 0);
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
|
||||
}
|
||||
if (ret == -1) {
|
||||
out_data[0] = -1;
|
||||
@@ -47,7 +47,7 @@ void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
}
|
||||
int encoder_count = msg_rcv.mtext[0];
|
||||
|
||||
for (int i = 0; i < encoder_count * 2 + 2; i++) {
|
||||
for (int i = 0; i < encoder_count * 3 + 2; i++) {
|
||||
out_data[i] = msg_rcv.mtext[i];
|
||||
}
|
||||
return;
|
||||
|
||||
@@ -35,5 +35,5 @@ struct msgdata {
|
||||
|
||||
// Message record for the KV signal: a long type tag followed by an int payload.
struct msgdatakv {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * 2 + 2]; // encoder_count, layer_id, bid- pair
|
||||
int mtext[MAX_BSZ * 3 + 2]; // [0] encoder_count, [1] layer_id, then one (bid, seq_lens_decoder, seq_lens_encoder) triple per encoder entry
|
||||
};
|
||||
@@ -64,9 +64,10 @@ struct RemoteCacheKvIpc {
|
||||
int encoder_count = 0;
|
||||
for (int i = 0; i < real_bsz; i++) {
|
||||
if (seq_lens_encoder[i] > 0) {
|
||||
msg_sed.mtext[3 * encoder_count + 2] = i;
|
||||
msg_sed.mtext[3 * encoder_count + 3] = seq_lens_decoder[i];
|
||||
msg_sed.mtext[3 * encoder_count + 4] = seq_lens_encoder[i];
|
||||
encoder_count++;
|
||||
msg_sed.mtext[2 * i + 2] = i;
|
||||
msg_sed.mtext[2 * i + 3] = seq_lens_decoder[i];
|
||||
}
|
||||
}
|
||||
msg_sed.mtext[0] = encoder_count;
|
||||
@@ -82,7 +83,7 @@ struct RemoteCacheKvIpc {
|
||||
|
||||
void CUDART_CB send_signal() {
|
||||
msg_sed.mtext[1] = layer_id_;
|
||||
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 2 + 2) * 4, 0)) == -1) {
|
||||
if ((msgsnd(msgid, &msg_sed, (MAX_BSZ * 3 + 2) * 4, 0)) == -1) {
|
||||
printf("kv signal full msg buffer\n");
|
||||
}
|
||||
layer_id_ = (layer_id_ + 1);
|
||||
|
||||
Reference in New Issue
Block a user