mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Online Chat API Support Return logprobs (#2777)
* online chat support logprobs * check xpu * check vl_gpu_model_runner and xpu_model_runner * get_worker() check platform
This commit is contained in:
@@ -24,16 +24,18 @@
|
||||
#endif
|
||||
|
||||
// Capacity constants that size the arrays of the msgdata message payload.
//
// MAX_BSZ: number of per-entry slots reserved in each payload array
//          (presumably the maximum batch size — confirm against the sender).
#define MAX_BSZ 512
// K: each entry reserves K + 1 slots in the token and score arrays
//    (presumably the top-K candidates plus one extra slot — confirm against
//    the sender). Defined exactly once; a second, different definition of K
//    would be a conflicting macro redefinition.
#define K 20
|
||||
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
|
||||
float mtext_f[MAX_BSZ * (K + 1)]; // score
|
||||
int mtext_ranks[MAX_BSZ]; // ranks
|
||||
};
|
||||
|
||||
void GetOutputTopK(const paddle::Tensor& x,
|
||||
const paddle::Tensor& scores,
|
||||
const paddle::Tensor& ranks,
|
||||
int k,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
|
||||
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
|
||||
float* scores_data = const_cast<float*>(scores.data<float>());
|
||||
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid,
|
||||
&msg_rcv,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
|
||||
0,
|
||||
IPC_NOWAIT);
|
||||
} else {
|
||||
ret = msgrcv(msgid,
|
||||
&msg_rcv,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
|
||||
scores_data[offset] = msg_rcv.mtext_f[offset];
|
||||
}
|
||||
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Registers the `get_output_topk` custom op with GetOutputTopK as its kernel.
//
// The tensor lists must match the kernel signature (x, scores, ranks). Every
// tensor is mapped in place to its `*_out` counterpart because the kernel
// writes the received tokens, scores and ranks directly into the buffers of
// the tensors it is given. The attributes mirror the kernel's trailing
// parameters (k, rank_id, wait_flag).
PD_BUILD_STATIC_OP(get_output_topk)
    .Inputs({"x", "scores", "ranks"})
    .Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
    .Outputs({"x_out", "scores_out", "ranks_out"})
    .SetInplaceMap({{"x", "x_out"},
                    {"scores", "scores_out"},
                    {"ranks", "ranks_out"}})
    .SetKernelFn(PD_KERNEL(GetOutputTopK));
|
||||
|
||||
Reference in New Issue
Block a user