[Feature] Online Chat API Support Return logprobs (#2777)

* online chat support logprobs

* check xpu

* check vl_gpu_model_runner and xpu_model_runner

* get_worker() check platform
This commit is contained in:
chen
2025-07-10 16:33:40 +08:00
committed by GitHub
parent 24f934f1f9
commit d33105baeb
22 changed files with 608 additions and 114 deletions

View File

@@ -24,16 +24,18 @@
#endif
#define MAX_BSZ 512
#define K 10
#define K 20
// Message buffer layout read by GetOutputTopK (it accesses mtext, mtext_f and
// mtext_ranks on the buffer it hands to msgrcv).
// NOTE: the byte count passed to msgrcv in GetOutputTopK is written out by
// hand as the sum of the three array sizes below using 4 bytes per element,
// so any change to an array length or element type here must be mirrored in
// that expression.
struct msgdata {
// Leading `long` member; presumably the System V message type that msgrcv
// expects at the start of the buffer -- confirm against the sender.
long mtype;
// Integer payload: 2 header slots plus (K + 1) slots per batch entry for up
// to MAX_BSZ entries.
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
// Float payload: (K + 1) slots per batch entry, copied into the scores
// tensor at the same offset by GetOutputTopK.
float mtext_f[MAX_BSZ * (K + 1)]; // score
// One int per batch entry, widened to int64_t when copied into the ranks
// tensor by GetOutputTopK.
int mtext_ranks[MAX_BSZ]; // ranks
};
void GetOutputTopK(const paddle::Tensor& x,
const paddle::Tensor& scores,
const paddle::Tensor& ranks,
int k,
int64_t rank_id,
bool wait_flag) {
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
float* scores_data = const_cast<float*>(scores.data<float>());
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
0,
IPC_NOWAIT);
} else {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
0,
0);
}
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
scores_data[offset] = msg_rcv.mtext_f[offset];
}
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
}
return;
}
// Registers the static custom op `get_output_topk` with GetOutputTopK as its
// kernel function. The attributes k, rank_id and wait_flag match the kernel's
// trailing parameters, and each tensor input is mapped in place onto the
// output carrying the same name plus an "_out" suffix.
PD_BUILD_STATIC_OP(get_output_topk)
.Inputs({"x", "scores"})
.Inputs({"x", "scores", "ranks"})
.Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out", "scores_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
.Outputs({"x_out", "scores_out", "ranks_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
.SetKernelFn(PD_KERNEL(GetOutputTopK));