// Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 08:16:42 +08:00 (148 lines, 5.7 KiB, plaintext).
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "helper.h"
|
|
#include "paddle/extension.h"
|
|
|
|
|
|
// Rotates a single (x, y) element pair of one head in place.
//
//   arr        - start of the head's data; the two paired elements are read
//                and then overwritten.
//   cos_ptr    - cosine table; entry `rot_offset` is used.
//   sin_ptr    - sine table; entry `rot_offset` is used.
//   rot_offset - index of the pair being rotated.
//   embed_dim  - distance between the paired elements when IS_NEOX is true.
//
// Pairing:
//   IS_NEOX == true  : arr[rot_offset] with arr[embed_dim + rot_offset]
//   IS_NEOX == false : arr[2 * rot_offset] with arr[2 * rot_offset + 1]
template <typename T, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_kernel(
    T* __restrict__ arr,
    const T* __restrict__ cos_ptr,
    const T* __restrict__ sin_ptr,
    int rot_offset,
    int embed_dim) {
  // IS_NEOX is a template constant, so these selects are resolved at compile
  // time.
  const int x_index = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int y_index = IS_NEOX ? embed_dim + rot_offset : 2 * rot_offset + 1;

  // Both layouts use table entry `rot_offset`: in the interleaved layout
  // x_index / 2 is exactly rot_offset.
  const T cos_val = cos_ptr[rot_offset];
  const T sin_val = sin_ptr[rot_offset];

  // Read both inputs before writing either output; the rotation needs the
  // original x and y.
  const T x = arr[x_index];
  const T y = arr[y_index];
  arr[x_index] = x * cos_val - y * sin_val;
  arr[y_index] = y * cos_val + x * sin_val;
}
|
|
|
|
|
|
// Applies rotary position embedding in place to the query and key data of
// every token.
//
// Launch layout: one thread block per token (blockIdx.x is the token index).
// Within a block, threads walk the (head, rotation pair) work items with a
// step of blockDim.x, so any non-zero block size covers all of them. No shared
// memory is used.
//
//   query / key      - rotated in place; token t starts at t * query_stride
//                      (resp. t * key_stride) elements, and head h starts
//                      h * head_size elements after that.
//   position_ids     - one cache row index per token.
//   cos_sin_cache    - rows of rot_dim elements; the first rot_dim / 2 entries
//                      of a row are used as cosines, the following
//                      rot_dim / 2 entries as sines.
//   rot_dim          - number of cache elements per row.
//   num_heads / num_kv_heads - head counts for query and key.
//   head_size        - elements per head.
template <typename T, bool IS_NEOX>
__global__ void apply_rotary_embedding_kernel(
    T* __restrict__ query,  // [num_tokens, num_heads, head_size]
    T* __restrict__ key,    // [num_tokens, num_kv_heads, head_size]
    const int* __restrict__ position_ids,  // [num_tokens]
    const T* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int64_t query_stride,
    const int64_t key_stride,
    const int num_heads,
    const int num_kv_heads,
    const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  const int pos = position_ids[token_idx];

  // Widen before multiplying: pos * rot_dim in 32-bit int could wrap for a
  // large cache and produce a bogus pointer.
  const T* cache_ptr = cos_sin_cache + static_cast<int64_t>(pos) * rot_dim;

  const int embed_dim = rot_dim / 2;
  const T* cos_ptr = cache_ptr;
  const T* sin_ptr = cache_ptr + embed_dim;

  // Query: one work item per (head, rotation pair).
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    // 64-bit offset arithmetic throughout, for the same overflow reason.
    const int64_t token_head =
        token_idx * query_stride + static_cast<int64_t>(head_idx) * head_size;
    apply_token_rotary_embedding_kernel<T, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  // Key: same scheme with the key head count and stride.
  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    const int64_t token_head =
        token_idx * key_stride + static_cast<int64_t>(head_idx) * head_size;
    apply_token_rotary_embedding_kernel<T, IS_NEOX>(
        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}
|
|
|
|
|
|
// Host entry point: applies rotary position embedding to `query` and `key`
// in place by launching apply_rotary_embedding_kernel on query.stream().
//
// Parameters:
//   query         - [num_tokens, num_heads, head_size] or
//                   [num_tokens, num_heads * head_size]; modified in place.
//   key           - [num_tokens, num_kv_heads, head_size] or
//                   [num_tokens, num_kv_heads * head_size]; modified in place.
//   position_ids  - [num_tokens]; read as int.
//   cos_sin_cache - [max_position, rot_dim]; read with query's element type.
//   head_size     - elements per head; must be positive.
//   is_neox       - selects the IS_NEOX=true kernel instantiation instead of
//                   the IS_NEOX=false one.
//
// Throws (PD_THROW) when head_size is not positive or num_tokens > 65535.
// An empty batch (num_tokens == 0) is a no-op. The launch is asynchronous;
// this function does not synchronize the stream.
void FusedRotaryPositionEncoding(
    paddle::Tensor& query,  // [num_tokens, num_heads, head_size] or
                            // [num_tokens, num_heads * head_size]
    paddle::Tensor& key,
    // [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
    // head_size]
    const paddle::Tensor& position_ids,   // [num_tokens]
    const paddle::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    int head_size,
    bool is_neox) {
  // head_size is used as a divisor below; reject it explicitly instead of
  // dividing by zero or deriving negative head counts.
  if (head_size <= 0) {
    PD_THROW("fused_rotary_position_encoding requires head_size > 0.");
  }

  const int64_t num_tokens = query.dims()[0];
  // Nothing to rotate. Returning here also avoids dividing by num_tokens
  // below and launching a zero-sized grid.
  if (num_tokens == 0) {
    return;
  }
  if (num_tokens > 65535) {
    PD_THROW(
        "apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
  }

  const int num_heads = query.numel() / num_tokens / head_size;
  const int num_kv_heads = key.numel() / num_tokens / head_size;
  const int rot_dim = cos_sin_cache.dims()[1];
  const int64_t query_stride = static_cast<int64_t>(num_heads) * head_size;
  const int64_t key_stride = static_cast<int64_t>(num_kv_heads) * head_size;

  // One block per token. Threads are sized from the query work items, capped
  // at 512 and clamped to at least 1, because a zero-thread block is an
  // invalid launch configuration. The kernel loops with a blockDim.x step, so
  // any non-zero size is correct.
  const int64_t query_pairs = static_cast<int64_t>(num_heads) * rot_dim / 2;
  dim3 grid(num_tokens);
  dim3 block(std::max<int64_t>(1, std::min<int64_t>(query_pairs, 512)));

  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
      query.dtype(), "apply_rotary_embedding_kernel", [&] {
        // Resolve the typed pointers once; both instantiations take the same
        // arguments.
        data_t* query_ptr = query.data<data_t>();
        data_t* key_ptr = key.data<data_t>();
        const int* position_ptr = position_ids.data<int>();
        const data_t* cache_ptr = cos_sin_cache.data<data_t>();
        if (is_neox) {
          apply_rotary_embedding_kernel<data_t, true>
              <<<grid, block, 0, query.stream()>>>(query_ptr,
                                                   key_ptr,
                                                   position_ptr,
                                                   cache_ptr,
                                                   rot_dim,
                                                   query_stride,
                                                   key_stride,
                                                   num_heads,
                                                   num_kv_heads,
                                                   head_size);
        } else {
          apply_rotary_embedding_kernel<data_t, false>
              <<<grid, block, 0, query.stream()>>>(query_ptr,
                                                   key_ptr,
                                                   position_ptr,
                                                   cache_ptr,
                                                   rot_dim,
                                                   query_stride,
                                                   key_stride,
                                                   num_heads,
                                                   num_kv_heads,
                                                   head_size);
        }
      });
}
|
|
|
|
// Registers the `fused_rotary_position_encoding` custom op and binds it to
// FusedRotaryPositionEncoding. The in-place map ties each output to the input
// it aliases (query -> query_out, key -> key_out).
PD_BUILD_STATIC_OP(fused_rotary_position_encoding)
    .Inputs({"query", "key", "position_ids", "cos_sin_cache"})
    .Outputs({"query_out", "key_out"})
    .Attrs({"head_size: int", "is_neox: bool"})
    .SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
    .SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
|