// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#include <cstdlib>
#include <iostream>
#include <string>

#include "paddle/extension.h"
#include "msg_utils.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#define MAX_BSZ 512

// Polls the KV-cache signal queue (key derived from "/opt/" and
// 1024 + rank_id) and copies the received encoder signal into x.
// Writes -1 sentinels into x when no message is pending.
void GetOutputKVSignal(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag) {
    int msg_queue_id = 1024 + rank_id;
    static struct msgdatakv msg_rcv;
    static key_t key = ftok("/opt/", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);

    int* out_data = const_cast<int*>(x.data<int>());
    int ret = -1;
    if (!wait_flag) {
        ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
    } else {
        ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
    }
    if (ret == -1) {
        // No message available: report sentinel values to the caller.
        out_data[0] = -1;
        out_data[1] = -1;
        return;
    }
    int encoder_count = msg_rcv.mtext[0];

    for (int i = 0; i < encoder_count * 3 + 2; i++) {
        out_data[i] = msg_rcv.mtext[i];
    }
    return;
}

// Reads one result message from a System V queue (key derived from
// "/dev/shm" and msg_queue_id, optionally overridden by the
// INFERENCE_MSG_QUEUE_ID environment variable) and copies it into x.
// Writes a -2 sentinel into x[0] when no message is pending.
void GetOutputEp(const paddle::Tensor& x,
                 int64_t rank_id,
                 bool wait_flag,
                 int msg_queue_id) {
    static struct msgdata msg_rcv;
    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }
#ifdef GET_OUTPUT_DEBUG
    std::cout << "msg_queue_id is: " << msg_queue_id << std::endl;
#endif
    // static key_t key = ftok("/dev/shm", msg_queue_id);
    // static int msgid = msgget(key, IPC_CREAT | 0666);
    key_t key = ftok("/dev/shm", msg_queue_id);
    int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
    std::cout << "get_output_key: " << key << std::endl;
    std::cout << "get_output msgid: " << msgid << std::endl;
#endif
    int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
    int ret = -1;
    if (!wait_flag) {
        ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
    } else {
        ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
    }
    if (ret == -1) {
        // No message available: mark the output as empty.
        out_data[0] = -2;
        out_data[1] = 0;
        return;
    }
    int bsz = msg_rcv.mtext[1];

    for (int64_t i = 0; i < bsz + 2; i++) {
        out_data[i] = (int64_t)msg_rcv.mtext[i];
    }
#ifdef GET_OUTPUT_DEBUG
    std::cout << "get_output finished: " << msgid << std::endl;
#endif
    return;
}

// Static-graph variant: always reads from queue id 1.
void GetOutputEPStatic(const paddle::Tensor& x,
                       int64_t rank_id,
                       bool wait_flag) {
    GetOutputEp(x, rank_id, wait_flag, 1);
}

// Dynamic variant: the queue id is passed in as an attribute.
void GetOutputEPDynamic(const paddle::Tensor& x,
                        int64_t rank_id,
                        bool wait_flag,
                        int msg_queue_id) {
    GetOutputEp(x, rank_id, wait_flag, msg_queue_id);
}

PD_BUILD_STATIC_OP(get_output_ep)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t", "wait_flag: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutputEPStatic));

PD_BUILD_STATIC_OP(get_output_ep_dynamic)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t", "wait_flag: bool", "msg_queue_id: int"})
.Outputs({"x_out"}) .SetInplaceMap({{"x", "x_out"}}) .SetKernelFn(PD_KERNEL(GetOutputEPDynamic));