// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#include <cstdlib>
#include <iostream>
#include <string>

#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

#include "speculate_msg.h"

// Reads one speculative-decoding output message from a System V message
// queue and writes it, in place, into the int64 tensor `x`.
void SpeculateGetOutput(const paddle::Tensor& x,
                        int64_t rank_id,
                        bool wait_flag,
                        int msg_queue_id,
                        bool get_each_rank) {
    // When only rank 0 consumes the queue, non-zero ranks return immediately.
    if (!get_each_rank && rank_id > 0) {
        return;
    }
    // The INFERENCE_MSG_QUEUE_ID environment variable overrides the queue id.
    if (const char* inference_msg_queue_id_env_p =
            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
        std::string inference_msg_queue_id_env_str(
            inference_msg_queue_id_env_p);
        int inference_msg_queue_id_from_env =
            std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
                  << inference_msg_queue_id_from_env << std::endl;
#endif
        msg_queue_id = inference_msg_queue_id_from_env;
    }
    // The key and queue handle are static, so they are resolved once on the
    // first call and reused afterwards.
    static struct speculate_msgdata msg_rcv;
    static key_t key = ftok("./", msg_queue_id);
    static int msgid = msgget(key, IPC_CREAT | 0666);

    int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
    int ret = -1;
    if (!wait_flag) {
        // Non-blocking receive: return immediately if the queue is empty.
        ret = msgrcv(msgid,
                     &msg_rcv,
                     (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
                     0,
                     IPC_NOWAIT);
    } else {
        // Blocking receive: wait until a message arrives.
        ret = msgrcv(msgid,
                     &msg_rcv,
                     (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
                     0,
                     0);
    }
    if (ret == -1) {
        // No message available (or msgrcv failed): signal "not ready".
        out_data[0] = -2;
        out_data[1] = 0;
        return;
    }

    // mtext[1] carries the batch size of this message.
    int bsz = msg_rcv.mtext[1];

    // Copy the whole fixed-size int payload into the int64 output tensor.
    for (int64_t i = 0; i < MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2; i++) {
        out_data[i] = (int64_t)msg_rcv.mtext[i];
    }
    return;
}

void SpeculateGetOutputStatic(const paddle::Tensor& x,
                              int64_t rank_id,
                              bool wait_flag,
                              bool get_each_rank) {
    SpeculateGetOutput(x, rank_id, wait_flag, 1, get_each_rank);
}

void SpeculateGetOutputDynamic(const paddle::Tensor& x,
                               int64_t rank_id,
                               bool wait_flag,
                               int msg_queue_id,
                               bool get_each_rank) {
    SpeculateGetOutput(x, rank_id, wait_flag, msg_queue_id, get_each_rank);
}

PD_BUILD_STATIC_OP(speculate_get_output)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t", "wait_flag: bool", "get_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutputStatic));

PD_BUILD_STATIC_OP(speculate_get_output_dynamic)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t",
            "wait_flag: bool",
            "msg_queue_id: int",
            "get_each_rank: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetOutputDynamic));
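
// --------------------------------------------------------------------------
// Illustrative sketch (disabled by default): how a producer could publish one
// step of speculative-decoding output into the queue that SpeculateGetOutput
// drains. This is a hedged example, not part of the op. It assumes the layout
// suggested by the reader above: `speculate_msgdata` (from speculate_msg.h)
// holds a `long mtype` plus an int payload of
// MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2 entries, with mtext[0] used as a
// status flag, mtext[1] as the batch size, then MAX_BSZ per-request accept
// counts followed by MAX_BSZ * MAX_DRAFT_TOKENS token ids. The names
// SpeculateSendExample, accept_nums, and tokens are hypothetical. Compile it
// in explicitly with -DSPECULATE_MSG_SENDER_EXAMPLE if you want to try it.
#ifdef SPECULATE_MSG_SENDER_EXAMPLE
static int SpeculateSendExample(int msg_queue_id,
                                int status_flag,
                                int bsz,
                                const int* accept_nums,  // bsz entries
                                const int* tokens) {  // bsz * MAX_DRAFT_TOKENS
    static struct speculate_msgdata msg_snd;
    // Same key derivation and queue creation as the reader above.
    key_t key = ftok("./", msg_queue_id);
    int msgid = msgget(key, IPC_CREAT | 0666);

    msg_snd.mtype = 1;               // any positive type; the reader passes msgtyp = 0
    msg_snd.mtext[0] = status_flag;  // assumed stop / not-stop flag
    msg_snd.mtext[1] = bsz;          // batch size read back by SpeculateGetOutput
    // Fill accept counts, zero-padding unused batch slots.
    for (int i = 0; i < MAX_BSZ; i++) {
        msg_snd.mtext[2 + i] = (i < bsz) ? accept_nums[i] : 0;
    }
    // Fill draft tokens (MAX_DRAFT_TOKENS per batch slot), zero-padding the rest.
    for (int i = 0; i < MAX_BSZ * MAX_DRAFT_TOKENS; i++) {
        msg_snd.mtext[2 + MAX_BSZ + i] =
            (i < bsz * MAX_DRAFT_TOKENS) ? tokens[i] : 0;
    }
    // Payload size must match the msgrcv calls above (int payload, 4 bytes each).
    return msgsnd(msgid,
                  &msg_snd,
                  (MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2) * 4,
                  0);
}
#endif  // SPECULATE_MSG_SENDER_EXAMPLE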