Files
FastDeploy/custom_ops/gpu_ops/ngram_mask.cu
2025-06-09 19:20:15 +08:00

480 lines
16 KiB
Plaintext

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hiprand.h>
#include <hiprand_kernel.h>
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#define GPU(str) hip##str
#else
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include "cub/cub.cuh"
#define GPU(str) cuda##str
#endif
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define STOP_LIST_BS 385
// Maps a paddle::DataType enum value onto two concrete C++ types:
// DataType (the device-side compute type used inside kernels) and
// data_t (the host-facing paddle scalar type used for tensor access).
template <paddle::DataType D>
class PDTraits;

template <>
class PDTraits<paddle::DataType::FLOAT32> {
 public:
  using DataType = float;
  using data_t = float;
};

template <>
class PDTraits<paddle::DataType::FLOAT16> {
 public:
  using DataType = half;
  using data_t = paddle::float16;
};
// Host-side trie (prefix tree) over stop-token sequences. Each node
// lazily allocates a fixed-size child array of STOP_LIST_BS slots;
// slots are filled left to right and token_id_ == -1 marks an unused
// slot (so the first -1 slot terminates any scan).
class TreeNode {
public:
// Token stored at this node; -1 for the root / an empty slot.
int token_id_;
// Lazily-allocated array of STOP_LIST_BS child slots (nullptr until
// the first child is inserted).
TreeNode *children_;
// Number of occupied child slots.
int children_node_len_;
// Fresh node: no token, no children allocated yet.
TreeNode() : token_id_(-1), children_node_len_(0) { children_ = nullptr; }
// Scans up to node_num child slots of `node` for `token`.
// On a match: returns true with *idx = matching slot index.
// On a miss: returns false with *idx = the first free slot seen
// (token_id_ == -1) or slot 0 when no child array exists yet.
// NOTE(review): if all node_num slots are occupied and none match,
// the loop falls through WITHOUT writing *idx, and insert() then uses
// it uninitialized — confirm a node can never exceed STOP_LIST_BS
// distinct children.
__host__ __device__ bool is_in_tree(TreeNode *node,
int node_num,
int token,
int *idx) {
for (int i = 0; i < node_num; i++) {
if (node->children_) {
TreeNode *tmp = &node->children_[i];
if (tmp->token_id_ == token) {
*idx = i;
return true;
} else if (tmp->token_id_ != -1) {
// Occupied by a different token: keep scanning.
continue;
} else {
// First free slot: remember it for a potential insert.
*idx = i;
break;
}
} else {
// No child array yet: slot 0 is where the child would go.
*idx = i;
break;
}
}
return false;
}
// Inserts one stop sequence (stop_list[0..stop_list_len)) into the
// trie, walking/creating one node per token. A -1 token terminates
// the sequence early (rows are -1-padded).
__host__ __device__ void insert(const int *stop_list,
const int stop_list_len) {
TreeNode *node = this;
for (int i = 0; i < stop_list_len; i++) {
int token = stop_list[i];
if (token == -1) break;
int idx;
bool in_tree = is_in_tree(node, STOP_LIST_BS, token, &idx);
if (!in_tree) {
node->children_node_len_++;
if (!node->children_) {
// Lazy allocation of the full fixed-capacity child array.
node->children_ = new TreeNode[STOP_LIST_BS];
}
node->children_[idx].token_id_ = token;
}
// Descend into the (existing or freshly-created) child.
node = &(node->children_[idx]);
}
}
// Walks the trie along prefix_token[0..token_len). On a full match,
// copies the matched node's child token ids into `res` and their
// count into *res_len; on any miss sets *res_len = 0 and returns.
__host__ __device__ void search(const int *prefix_token,
int token_len,
int *res,
int *res_len) {
TreeNode *node = this;
int idx;
for (int i = 0; i < token_len; i++) {
if (node) {
bool in_tree = is_in_tree(
node, node->children_node_len_, prefix_token[i], &idx);
if (in_tree) {
node = &(node->children_[idx]);
} else {
*res_len = 0;
return;
}
} else {
break;
}
}
if (node) {
int id = 0;
// Children are packed left to right, so the first -1 ends the list.
for (int i = 0; i < node->children_node_len_; i++) {
if (node->children_[i].token_id_ != -1) {
res[id++] = node->children_[i].token_id_;
} else {
break;
}
}
*res_len = id;
}
}
// Frees one child array allocated by insert() (array delete matches
// the `new TreeNode[STOP_LIST_BS]` above).
void destroy_node(TreeNode *node) {
if (node) {
delete[] node;
}
}
// Post-order teardown: recursively free every descendant's child
// array. The head node object itself is not freed here.
void destroy(TreeNode *head) {
if (head->children_node_len_ == 0) {
// last layer
return;
}
for (int i = 0; i < head->children_node_len_; i++) {
destroy(&head->children_[i]);
}
destroy_node(head->children_);
}
};
// Device-side, pointer-free mirror of TreeNode. The whole trie is
// flattened into one contiguous TreeNodeGPU array; a node's children
// occupy the `children_node_len_` consecutive slots starting at array
// index `next` (next == -1 marks a leaf). All indices are relative to
// the array head, so the structure is position-independent on device.
class TreeNodeGPU {
 public:
  int token_id_;           // token stored at this node (-1 for the root)
  int children_node_len_;  // number of children
  int next;                // array index of the first child; -1 if leaf

  // Linear scan of `node`'s children for `token`. Writes the child's
  // absolute array index into *idx, or -1 when not found.
  __device__ void is_in_tree(TreeNodeGPU *node,
                             TreeNodeGPU *head,
                             int token,
                             int *idx) {
    for (int i = 0; i < node->children_node_len_; i++) {
      TreeNodeGPU *children = head + node->next + i;
      if (children->token_id_ == token) {
        *idx = node->next + i;
        return;
      }
    }
    *idx = -1;
  }

  // Walks the trie along prefix_token[0..token_len). On a full match,
  // copies the matched node's child token ids into `res` and their
  // count into *res_len; on any miss sets *res_len = 0 and returns.
  // Must be called on the array head (index 0).
  __device__ void search(const int *prefix_token,
                         int token_len,
                         int *res,
                         int *res_len) {
    TreeNodeGPU *node = this;
    int idx;
    for (int i = 0; i < token_len; i++) {
      is_in_tree(node, this, prefix_token[i], &idx);
      if (idx != -1) {
        node = this + idx;
      } else {
        *res_len = 0;
        return;
      }
    }
    for (int i = 0; i < node->children_node_len_; i++) {
      res[i] = (this + node->next + i)->token_id_;
    }
    res_len[0] = node->children_node_len_;
  }

  // Same walk, but on a full match directly masks the matched node's
  // child tokens in `logits`; on a miss leaves logits untouched.
  template <typename T>
  __device__ void search(const int *prefix_token, int token_len, T *logits) {
    TreeNodeGPU *node = this;
    int idx;
    for (int i = 0; i < token_len; i++) {
      is_in_tree(node, this, prefix_token[i], &idx);
      if (idx != -1) {
        node = this + idx;
      } else {
        return;
      }
    }
    for (int i = 0; i < node->children_node_len_; i++) {
      // Float literal + explicit cast: avoids a double constant being
      // converted per-element when T is half or float.
      logits[(this + node->next + i)->token_id_] = static_cast<T>(-10000.0f);
    }
  }
};
// Masks stop-token continuations directly in the logits.
// Launch layout: one block per batch row (blockIdx.x), one thread per
// suffix length (threadIdx.x). Thread ti searches the suffix of length
// ti+1 that ends at the last token of its row.
template <typename T>
__global__ void search_on_gpu(TreeNodeGPU *head,
                              const int *input_sequences,
                              T *logits,
                              const int logits_len,
                              const int max_input_len) {
  const int batch_idx = blockIdx.x;
  const int suffix_idx = threadIdx.x;
  if (suffix_idx >= max_input_len) return;
  const int *suffix_start = input_sequences + batch_idx * max_input_len +
                            (max_input_len - suffix_idx - 1);
  head->search(suffix_start, suffix_idx + 1, logits + batch_idx * logits_len);
}
// Collects stop-token continuations into res/res_len instead of
// masking logits (used by the reverse path). Same launch layout as the
// templated overload: one block per row, one thread per suffix length.
// res is laid out [bs, max_input_len, STOP_LIST_BS]; res_len is
// [bs, max_input_len].
__global__ void search_on_gpu(TreeNodeGPU *head,
                              const int *input_sequences,
                              int *res,
                              int *res_len,
                              const int max_input_len) {
  const int batch_idx = blockIdx.x;
  const int suffix_idx = threadIdx.x;
  if (suffix_idx >= max_input_len) return;
  const int *suffix_start = input_sequences + batch_idx * max_input_len +
                            (max_input_len - suffix_idx - 1);
  // Flattened (row, suffix) index shared by both output tensors.
  const int flat = batch_idx * max_input_len + suffix_idx;
  head->search(suffix_start,
               suffix_idx + 1,
               res + flat * STOP_LIST_BS,
               res_len + flat);
}
// Reverse mask: for batch row bi, set every logit to -10000 EXCEPT the
// token ids collected in `res` (union over all suffix matches of that
// row). Launch layout: one block per batch row; threads stride over
// the vocabulary (logits_len).
template <typename T>
__global__ void set_value_reverse(const int *res,
                                  const int *res_len,
                                  T *logits,
                                  const int max_input_len,
                                  const int logits_len) {
  const int bi = blockIdx.x;
  T *logits_now = logits + bi * logits_len;
  // Hoisted: this row's match-count array is invariant over i.
  const int *res_len_now = res_len + bi * max_input_len;
  for (int i = threadIdx.x; i < logits_len; i += blockDim.x) {
    bool set_flag = true;
    for (int j = 0; j < max_input_len; j++) {
      const int *res_now =
          res + bi * max_input_len * STOP_LIST_BS + j * STOP_LIST_BS;
      for (int k = 0; k < res_len_now[j]; k++) {
        if (i == res_now[k]) {
          // Token i is an allowed continuation: keep its logit.
          set_flag = false;
          break;
        }
      }
      if (!set_flag) break;
    }
    // Float literal + cast avoids a per-element double->T conversion
    // when T is half.
    if (set_flag) logits_now[i] = static_cast<T>(-10000.0f);
  }
}
// Single-thread kernel (<<<1, 1>>>): initialize one slot of the
// flattened device tree. Used node-by-node while copying the host trie.
__global__ void setup_gpu_node(TreeNodeGPU *d_head,
                               const int token_id,
                               const int childre_node_num,
                               const int now_id,
                               const int next_id) {
  TreeNodeGPU &node = d_head[now_id];
  node.token_id_ = token_id;
  node.children_node_len_ = childre_node_num;
  node.next = next_id;
}
// Debug helper: dump one flattened tree node's fields from the device.
__global__ void print_kernel(TreeNodeGPU *d_node, int now_id) {
  const TreeNodeGPU *node = d_node + now_id;
  printf("NodeGPU token_id: %d, next_id: %d, children_num: %d\n",
         node->token_id_,
         node->next,
         node->children_node_len_);
}
// Recursively counts every node strictly below `node` (the node itself
// is not counted), accumulating into `res`.
void get_nodes_num(TreeNode *node, int &res) {
  const int n_children = node->children_node_len_;
  res += n_children;
  for (int i = 0; i < n_children; ++i) {
    get_nodes_num(node->children_ + i, res);
  }
}
// Inserts every row of the [STOP_LIST_BS, stop_list_len] stop-token
// matrix into the host-side trie rooted at `head`.
void setup_tree_cpu(TreeNode *head, const int *stop_list, int stop_list_len) {
  const int *row = stop_list;
  for (int i = 0; i < STOP_LIST_BS; ++i, row += stop_list_len) {
    head->insert(row, stop_list_len);
  }
}
// Recursively copies the pointer-based host trie into the flattened
// device array `d_head`. `now_id` is this node's slot; `next_id` is
// the next unassigned slot and is advanced as each node reserves a
// consecutive run of slots for its children, so siblings are always
// contiguous (which is what TreeNodeGPU::next relies on). Leaves are
// written with next == -1. `depth` is only carried for bookkeeping.
// NOTE(review): one <<<1,1>>> launch plus a DeviceSynchronize per node
// makes setup O(node_count) round-trips — tolerable only because this
// runs once, on the first kernel invocation.
void copy_tree(TreeNode *head,
TreeNodeGPU *d_head,
int depth,
int now_id,
int &next_id,
gpuStream_t stream) {
if (!head) {
return;
}
int tmp_next_id = next_id;
if (head->children_node_len_ == 0) {
// Leaf: no child run to point at.
tmp_next_id = -1;
}
setup_gpu_node<<<1, 1, 0, stream>>>(
d_head, head->token_id_, head->children_node_len_, now_id, tmp_next_id);
GPU(DeviceSynchronize)();
depth++;
// Reserve a contiguous run of slots for all children before recursing,
// so each child's subtree allocates after its siblings' slots.
int next_id_this_arrays = next_id;
next_id += head->children_node_len_;
for (int i = 0; i < head->children_node_len_; i++) {
int tmp_now_id = next_id_this_arrays + i;
copy_tree(
&(head->children_[i]), d_head, depth, tmp_now_id, next_id, stream);
}
}
// CPU reference of the res/res_len search_on_gpu kernel: for each batch
// row and each suffix length j+1 (ending at the row's last token),
// collect the stop-token continuations into res ([bs, max_input_len,
// STOP_LIST_BS]) and res_len ([bs, max_input_len]).
void search_on_cpu(TreeNode *head,
                   const int *input_sequences,
                   int *res,
                   int *res_len,
                   const int bs,
                   const int max_input_len) {
  for (int b = 0; b < bs; ++b) {
    const int *row = input_sequences + b * max_input_len;
    for (int j = 0; j < max_input_len; ++j) {
      // Flattened (row, suffix) index shared by both outputs.
      const int flat = b * max_input_len + j;
      head->search(row + max_input_len - j - 1,
                   j + 1,
                   res + flat * STOP_LIST_BS,
                   res_len + flat);
    }
  }
}
// Typed implementation of the ngram stop-word mask.
// On the first call, builds a host trie from stop_list_tensor, flattens
// it into a device array (d_head), then tears down the host trie. Both
// are cached in function-local statics, so the stop list is assumed
// identical across calls; d_head is intentionally never freed.
// NOTE(review): the static first-call init is not thread-safe, and
// stop_list_tensor is read through a raw host pointer — confirm callers
// pass a CPU-resident stop list.
// reverse == 0: set logits of tokens that would CONTINUE a stop
// sequence to -10000. reverse != 0: set every logit EXCEPT those
// continuations to -10000. Returns the masked copy of `logits`.
template <paddle::DataType D>
std::vector<paddle::Tensor> NgramMaskKernel(
    const paddle::Tensor &stop_list_tensor,
    const paddle::Tensor &input_sequences,
    const paddle::Tensor &logits,
    int reverse) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  std::vector<int64_t> stop_list_shape = stop_list_tensor.shape();
  std::vector<int64_t> input_shape = input_sequences.shape();
  std::vector<int64_t> logits_shape = logits.shape();
  int logits_len = logits_shape[1];
  // Work on a copy so the input logits tensor is left untouched.
  auto logits_out = logits.copy_to(logits.place(), false);
  auto cu_stream = input_sequences.stream();
  int bs = input_shape[0];
  int stop_list_len = stop_list_shape[1];
  int max_input_len = input_shape[1];
  // The search kernels use max_input_len as the block size; beyond the
  // hardware limit the launch would fail silently and no mask applied.
  PD_CHECK(max_input_len <= 1024,
           "ngram_mask: max_input_len must be <= 1024, got ",
           max_input_len);
  static int run_flag = 0;
  static TreeNode *head = new TreeNode();
  static TreeNodeGPU *d_head;
  if (!run_flag) {
    setup_tree_cpu(head, stop_list_tensor.data<int>(), stop_list_len);
    int node_num = 0;
    get_nodes_num(head, node_num);
    node_num++;  // +1 for the root slot
    printf("node_num: %d\n", node_num);
    run_flag++;
    // Checked allocation: a silent failure here would crash every
    // subsequent kernel with an invalid pointer.
    auto malloc_err = GPU(Malloc)(&d_head, node_num * sizeof(TreeNodeGPU));
    PD_CHECK(malloc_err == GPU(Success),
             "ngram_mask: failed to allocate device tree.");
    GPU(DeviceSynchronize)();
    int now_id = 0;
    int next_id = 1;
    copy_tree(head, d_head, 0, now_id, next_id, cu_stream);
    GPU(DeviceSynchronize)();
    head->destroy(head);  // host trie no longer needed once copied
  }
  int grid_size = bs;              // one block per batch row
  int block_size = max_input_len;  // one thread per suffix length
  if (reverse) {
    auto out_ids = paddle::empty({bs, max_input_len, STOP_LIST_BS},
                                 paddle::DataType::INT32,
                                 input_sequences.place());
    auto out_lens = paddle::empty({bs, max_input_len},
                                  paddle::DataType::INT32,
                                  input_sequences.place());
    // Pass 1: collect allowed continuations per (row, suffix length).
    search_on_gpu<<<grid_size, block_size, 0, cu_stream>>>(
        d_head,
        input_sequences.data<int>(),
        out_ids.data<int>(),
        out_lens.data<int>(),
        max_input_len);
    // Pass 2: mask every vocabulary entry NOT collected above.
    set_value_reverse<DataType_><<<grid_size, 256, 0, cu_stream>>>(
        out_ids.data<int>(),
        out_lens.data<int>(),
        reinterpret_cast<DataType_ *>(
            const_cast<data_t *>(logits_out.data<data_t>())),
        max_input_len,
        logits_len);
  } else {
    // Single pass: mask the continuations directly in the logits.
    search_on_gpu<DataType_><<<grid_size, block_size, 0, cu_stream>>>(
        d_head,
        input_sequences.data<int>(),
        reinterpret_cast<DataType_ *>(
            const_cast<data_t *>(logits_out.data<data_t>())),
        logits_len,
        max_input_len);
  }
  return {logits_out};
}
// Dtype dispatcher for the ngram_mask op: routes to the typed kernel
// for fp16/fp32 logits, rejects anything else.
std::vector<paddle::Tensor> NgramMask(const paddle::Tensor &stop_list_tensor,
                                      const paddle::Tensor &input_sequences,
                                      const paddle::Tensor &logits,
                                      int reverse) {
  const auto dtype = logits.type();
  if (dtype == paddle::DataType::FLOAT16) {
    return NgramMaskKernel<paddle::DataType::FLOAT16>(
        stop_list_tensor, input_sequences, logits, reverse);
  }
  if (dtype == paddle::DataType::FLOAT32) {
    return NgramMaskKernel<paddle::DataType::FLOAT32>(
        stop_list_tensor, input_sequences, logits, reverse);
  }
  PD_THROW(
      "NOT supported data type. "
      "Only float16 and float32 are supported. ");
}
// Shape inference: the masked output has exactly the logits shape.
std::vector<std::vector<int64_t>> NgramMaskInferShape(
    const std::vector<int64_t> &stop_list_tensor_shape,
    const std::vector<int64_t> &input_sequences_shape,
    const std::vector<int64_t> &logits_shape) {
  std::vector<std::vector<int64_t>> out_shapes;
  out_shapes.push_back(logits_shape);
  return out_shapes;
}
// Dtype inference: the masked output mirrors the logits dtype.
std::vector<paddle::DataType> NgramMaskInferDtype(
    const paddle::DataType &stop_list_tensor_dtype,
    const paddle::DataType &input_sequences_dtype,
    const paddle::DataType &logits_dtype) {
  return std::vector<paddle::DataType>{logits_dtype};
}
// Register the custom op with Paddle: inputs are the stop-word matrix,
// the generated token ids, and the logits; `reverse` selects masking
// mode; output is the masked logits copy.
PD_BUILD_STATIC_OP(ngram_mask)
.Inputs({"stop_list_tensor", "input_sequences", "logits"})
.Outputs({"logits_out"})
.Attrs({"reverse: int"})
.SetKernelFn(PD_KERNEL(NgramMask))
.SetInferShapeFn(PD_INFER_SHAPE(NgramMaskInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(NgramMaskInferDtype));