mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Optim PaddleOCR-VL (#4873)
* [Feature] Optim PaddleOCR-VL * fix bug
This commit is contained in:
@@ -1059,6 +1059,15 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
|
||||
const paddle::Tensor& decode_states,
|
||||
const paddle::Tensor& mask_rollback);
|
||||
|
||||
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& cos_emb,
|
||||
const paddle::Tensor& sin_emb,
|
||||
const int num_heads,
|
||||
const int head_dim);
|
||||
|
||||
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("get_expert_token_num",
|
||||
&GetExpertTokenNum,
|
||||
@@ -1648,4 +1657,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("update_attn_mask_offsets",
|
||||
&UpdateAttnMaskOffsets,
|
||||
"update attention mask");
|
||||
|
||||
m.def("fused_neox_rope_embedding",
|
||||
&FusedNeoxRopeEmbedding,
|
||||
"fused_neox_rope_embedding function");
|
||||
|
||||
m.def("gelu_tanh", &GeluTanh, "gelu_tanh function");
|
||||
}
|
||||
|
||||
140
custom_ops/gpu_ops/fused_neox_rope_embedding.cu
Normal file
140
custom_ops/gpu_ops/fused_neox_rope_embedding.cu
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void FusedNeoxRopeEmbeddingKernel(const T *__restrict__ qkv,
|
||||
const float *__restrict__ cos_emb,
|
||||
const float *__restrict__ sin_emb,
|
||||
T *__restrict__ q,
|
||||
T *__restrict__ k,
|
||||
T *__restrict__ v,
|
||||
const int64_t elem_cnt,
|
||||
const int num_head,
|
||||
const int last_dim) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
LoadT right_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int half_lastdim = last_dim / 2;
|
||||
const int hidden_size = num_head * half_lastdim;
|
||||
const int full_hidden_size = num_head * last_dim;
|
||||
const int offset = 3 * hidden_size;
|
||||
for (int64_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int bias = linear_index % offset;
|
||||
const int qkv_id = bias / hidden_size;
|
||||
const int qkv_bias = bias % hidden_size;
|
||||
const int hi = qkv_bias / half_lastdim;
|
||||
const int h_bias = qkv_bias % half_lastdim;
|
||||
const int base_idx_left = token_idx * 3 * full_hidden_size +
|
||||
qkv_id * full_hidden_size + hi * last_dim +
|
||||
h_bias;
|
||||
const int base_idx_right = base_idx_left + half_lastdim;
|
||||
const int emb_idx = token_idx * last_dim + h_bias;
|
||||
const int base_split_idx_left =
|
||||
token_idx * full_hidden_size + hi * last_dim + h_bias;
|
||||
const int base_split_idx_right = base_split_idx_left + half_lastdim;
|
||||
|
||||
// q,k,v output
|
||||
T *out_p = nullptr;
|
||||
if (qkv_id == 0) {
|
||||
out_p = q;
|
||||
} else if (qkv_id == 1) {
|
||||
out_p = k;
|
||||
} else {
|
||||
out_p = v;
|
||||
}
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
// do rope
|
||||
if (qkv_id < 2) {
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
float input_left = static_cast<float>(left_vec[i]);
|
||||
float input_right = static_cast<float>(right_vec[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
|
||||
int cur_idx_1 = base_split_idx_left + i;
|
||||
int cur_idx_2 = base_split_idx_right + i;
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &out_p[base_split_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &out_p[base_split_idx_right]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
|
||||
const paddle::Tensor &qkv,
|
||||
const paddle::Tensor &cos_emb,
|
||||
const paddle::Tensor &sin_emb,
|
||||
const int num_heads,
|
||||
const int head_dim) {
|
||||
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const auto &qkv_dims = qkv.dims();
|
||||
const int token_num = qkv_dims.size() == 2 ? qkv_dims[0] : qkv_dims[1];
|
||||
|
||||
auto stream = qkv.stream();
|
||||
paddle::Tensor q = GetEmptyTensor(
|
||||
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
|
||||
paddle::Tensor k = GetEmptyTensor(
|
||||
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
|
||||
paddle::Tensor v = GetEmptyTensor(
|
||||
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
|
||||
|
||||
int64_t elem_nums = token_num * num_heads * head_dim * 3 / 2;
|
||||
constexpr int PackSize = 4;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
FusedNeoxRopeEmbeddingKernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_ *>(qkv.data<data_t>()),
|
||||
cos_emb.data<float>(),
|
||||
sin_emb.data<float>(),
|
||||
reinterpret_cast<DataType_ *>(q.data<data_t>()),
|
||||
reinterpret_cast<DataType_ *>(k.data<data_t>()),
|
||||
reinterpret_cast<DataType_ *>(v.data<data_t>()),
|
||||
elem_nums,
|
||||
num_heads,
|
||||
head_dim);
|
||||
return {q, k, v};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(fused_neox_rope_embedding)
|
||||
.Inputs({"qkv", "cos_emb", "sin_emb"})
|
||||
.Outputs({"q", "k", "v"})
|
||||
.Attrs({"num_heads: int", "head_dim: int"})
|
||||
.SetKernelFn(PD_KERNEL(FusedNeoxRopeEmbedding));
|
||||
106
custom_ops/gpu_ops/gelu_tanh.cu
Normal file
106
custom_ops/gpu_ops/gelu_tanh.cu
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
__forceinline__ __device__ float tanh_ptx(float x) {
|
||||
float y;
|
||||
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
|
||||
return y;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float gelu_tanh_func(const float& val) {
|
||||
const float cdf =
|
||||
0.5f * (1.0f + tanh_ptx((0.7978845608028654f *
|
||||
(val + 0.044715f * val * val * val))));
|
||||
return val * cdf;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gelu_tanh_kernel(T* __restrict__ out,
|
||||
const T* __restrict__ input,
|
||||
const int d) {
|
||||
constexpr uint32_t kVecSize = 16 / sizeof(T);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t thread_idx = threadIdx.x;
|
||||
const int64_t stride = blockDim.x;
|
||||
const int64_t offset = token_idx * d;
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
|
||||
(__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) {
|
||||
vec_t x_vec;
|
||||
Load(input + offset + idx * kVecSize, &x_vec);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kVecSize; ++i) {
|
||||
x_vec[i] = static_cast<T>(gelu_tanh_func(static_cast<float>(x_vec[i])));
|
||||
}
|
||||
Store(x_vec, out + token_idx * d + idx * kVecSize);
|
||||
}
|
||||
|
||||
const int64_t remaining_offset = d - d % (stride * kVecSize);
|
||||
// process the remaining elements
|
||||
#pragma unroll 1
|
||||
for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) {
|
||||
float x = static_cast<float>(input[offset + remaining_offset + idx]);
|
||||
out[token_idx * d + remaining_offset + idx] =
|
||||
static_cast<T>(gelu_tanh_func(x));
|
||||
}
|
||||
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
|
||||
(__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
|
||||
int d = input.dims()[1];
|
||||
int64_t num_tokens = input.dims()[0];
|
||||
cudaStream_t stream = input.stream();
|
||||
|
||||
paddle::Tensor output =
|
||||
GetEmptyTensor(input.dims(), input.dtype(), input.place());
|
||||
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = num_tokens;
|
||||
config.blockDim = std::min(d / vec_size, 1024U);
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
|
||||
cudaLaunchKernelEx(&config,
|
||||
gelu_tanh_kernel<scalar_t>,
|
||||
output.data<scalar_t>(),
|
||||
input.data<scalar_t>(),
|
||||
d);
|
||||
});
|
||||
|
||||
return {output};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(gelu_tanh)
|
||||
.Inputs({"input"})
|
||||
.Outputs({"output"})
|
||||
.SetKernelFn(PD_KERNEL(GeluTanh));
|
||||
@@ -306,6 +306,8 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/limit_thinking_content_length_v1.cu",
|
||||
"gpu_ops/limit_thinking_content_length_v2.cu",
|
||||
"gpu_ops/update_attn_mask_offsets.cu",
|
||||
"gpu_ops/fused_neox_rope_embedding.cu",
|
||||
"gpu_ops/gelu_tanh.cu",
|
||||
]
|
||||
|
||||
# pd_disaggregation
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
## 1. Environment Preparation
|
||||
### 1.1 Support Status
|
||||
Recommended Hardware Configuration:
|
||||
- GPU Memory: 12GB or more
|
||||
- Shared Memory: 2GB or more
|
||||
- GPU Memory: 8GB or more
|
||||
- Shared Memory: 4GB or more
|
||||
|
||||
### 1.2 Install Fastdeploy
|
||||
|
||||
@@ -18,38 +18,38 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model PaddlePaddle/PaddleOCR-VL \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--port 8185 \
|
||||
--metrics-port 8186 \
|
||||
--engine-worker-queue-port 8187 \
|
||||
--max-model-len 16384 \
|
||||
--max-num-batched-tokens 16384 \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--max-num-seqs 128
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--max-num-seqs 256
|
||||
```
|
||||
|
||||
**Example 2:** Deploying a 16K Context Service on a Single RTX 4090 GPU
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model PaddlePaddle/PaddleOCR-VL \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--port 8185 \
|
||||
--metrics-port 8186 \
|
||||
--engine-worker-queue-port 8187 \
|
||||
--max-model-len 16384 \
|
||||
--max-num-batched-tokens 16384 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--max-num-seqs 196
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--max-num-seqs 256
|
||||
```
|
||||
|
||||
**Example 3:** Deploying a 16K Context Service on a Single A100 GPU
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model PaddlePaddle/PaddleOCR-VL \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--port 8185 \
|
||||
--metrics-port 8186 \
|
||||
--engine-worker-queue-port 8187 \
|
||||
--max-model-len 16384 \
|
||||
--max-num-batched-tokens 16384 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--max-num-seqs 256
|
||||
```
|
||||
|
||||
@@ -71,7 +71,7 @@ An example is a set of configurations that can run stably while also delivering
|
||||
> **Available GPU memory ratio during initialization**
|
||||
- **Parameters:** `--gpu-memory-utilization`
|
||||
- **Description:** Controls the available GPU memory for FastDeploy service initialization. The default value is 0.9, meaning 10% of the memory is reserved for backup.
|
||||
- **Recommendation:** It is recommended to use 0.8. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.
|
||||
- **Recommendation:** It is recommended to use 0.7. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.
|
||||
|
||||
#### 2.2.2 Chunked Prefill
|
||||
- **Parameters:** `--max-num-batched-tokens`
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
## 一、环境准备
|
||||
### 1.1 支持情况
|
||||
推荐硬件配置:
|
||||
- 显存:12GB显存及以上
|
||||
- 共享内存:2G及以上
|
||||
- 显存:8GB显存及以上
|
||||
- 共享内存:4G及以上
|
||||
|
||||
### 1.2 安装fastdeploy
|
||||
|
||||
@@ -18,12 +18,12 @@
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model PaddlePaddle/PaddleOCR-VL \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--port 8185 \
|
||||
--metrics-port 8186 \
|
||||
--engine-worker-queue-port 8187 \
|
||||
--max-model-len 16384 \
|
||||
--max-num-batched-tokens 16384 \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--max-num-seqs 128
|
||||
```
|
||||
|
||||
@@ -31,25 +31,25 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model PaddlePaddle/PaddleOCR-VL \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--port 8185 \
|
||||
--metrics-port 8186 \
|
||||
--engine-worker-queue-port 8187 \
|
||||
--max-model-len 16384 \
|
||||
--max-num-batched-tokens 16384 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--max-num-seqs 196
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--max-num-seqs 256
|
||||
```
|
||||
|
||||
**示例3:** A100上单卡部署16K上下文的服务
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model PaddlePaddle/PaddleOCR-VL \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--port 8185 \
|
||||
--metrics-port 8186 \
|
||||
--engine-worker-queue-port 8187 \
|
||||
--max-model-len 16384 \
|
||||
--max-num-batched-tokens 16384 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--max-num-seqs 256
|
||||
```
|
||||
|
||||
@@ -72,7 +72,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
> **初始化时可用的显存比例**
|
||||
- **参数:** `--gpu-memory-utilization`
|
||||
- **用处:** 用于控制 FastDeploy 初始化服务的可用显存,默认0.9,即预留10%的显存备用。
|
||||
- **推荐:** 推荐使用0.8。如果服务压测时提示显存不足,可以尝试调低该值。
|
||||
- **推荐:** 推荐使用0.7。如果服务压测时提示显存不足,可以尝试调低该值。
|
||||
|
||||
#### 2.2.2 Chunked Prefill
|
||||
- **参数:** `--max-num-batched-tokens`
|
||||
|
||||
@@ -197,7 +197,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}},
|
||||
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
|
||||
{"type": "text", "text": "OCR:"}
|
||||
]}
|
||||
],
|
||||
@@ -216,7 +216,7 @@ response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}},
|
||||
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
|
||||
{"type": "text", "text": "OCR:"}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -22,7 +22,6 @@ import paddle
|
||||
import paddle.nn as nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||
@@ -136,12 +135,8 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
|
||||
)
|
||||
|
||||
# Persistent buffers for CUDA graphs.
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
max_length = fd_config.scheduler_config.max_num_seqs * fd_config.model_config.max_model_len
|
||||
else:
|
||||
max_length = fd_config.model_config.max_model_len
|
||||
self._input_embeddings = paddle.zeros(
|
||||
[max_length, fd_config.model_config.hidden_size],
|
||||
self._decoder_input_embeddings = paddle.zeros(
|
||||
[fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
|
||||
dtype=fd_config.model_config.dtype,
|
||||
)
|
||||
|
||||
@@ -247,12 +242,19 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
|
||||
input_embeddings = self.get_input_embeddings(
|
||||
ids_remove_padding=ids_remove_padding, image_features=image_features
|
||||
)
|
||||
self._input_embeddings.copy_(input_embeddings, False)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_embeddings=self._input_embeddings,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
if forward_meta.step_use_cudagraph:
|
||||
self._decoder_input_embeddings.copy_(input_embeddings, False)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_embeddings=self._decoder_input_embeddings,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.model(
|
||||
input_embeddings=input_embeddings,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -21,39 +21,13 @@ import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddleformers.transformers.activations import ACT2FN
|
||||
from paddleformers.transformers.model_utils import PretrainedModel
|
||||
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import slice_fn
|
||||
|
||||
from .config import PaddleOCRVisionConfig
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
Dh = x.shape[-1]
|
||||
x1 = x[..., : Dh // 2]
|
||||
x2 = x[..., Dh // 2 :]
|
||||
return paddle.concat([-x2, x1], axis=-1)
|
||||
|
||||
|
||||
def _ensure_cos_sin_dim(cos, sin, dim_needed):
|
||||
last = cos.shape[-1]
|
||||
if last == dim_needed:
|
||||
return cos, sin
|
||||
elif last * 2 == dim_needed:
|
||||
cos = paddle.concat([cos, cos], axis=-1)
|
||||
sin = paddle.concat([sin, sin], axis=-1)
|
||||
return cos, sin
|
||||
else:
|
||||
raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}")
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(x, cos, sin):
|
||||
orig_dtype = x.dtype
|
||||
x = x.astype("float32")
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed.astype(orig_dtype)
|
||||
from .siglip_ops import get_activation_fn, neox_rope_embedding
|
||||
|
||||
|
||||
class SiglipAttention(nn.Layer):
|
||||
@@ -147,29 +121,12 @@ class SiglipAttention(nn.Layer):
|
||||
output_attentions: Optional[bool] = False,
|
||||
cu_seqlens: Optional[List[paddle.Tensor]] = None,
|
||||
max_seqlen: Optional[paddle.Tensor] = None,
|
||||
rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, # (cos, sin)
|
||||
cos_emb: Optional[paddle.Tensor] = None, # (cos, sin)
|
||||
sin_emb: Optional[paddle.Tensor] = None, # (cos, sin)
|
||||
):
|
||||
B, seq_length, D = hidden_states.shape
|
||||
|
||||
qkv = (
|
||||
self.qkv_proj(hidden_states)
|
||||
.reshape(
|
||||
[
|
||||
seq_length,
|
||||
3,
|
||||
self.num_heads,
|
||||
-1,
|
||||
]
|
||||
)
|
||||
.transpose(perm=[1, 0, 2, 3])
|
||||
)
|
||||
q, k, v = qkv.unbind(axis=0)
|
||||
cos, sin = rope_emb
|
||||
|
||||
# --------
|
||||
q = apply_rotary_pos_emb_vision(q, cos, sin)
|
||||
k = apply_rotary_pos_emb_vision(k, cos, sin)
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = neox_rope_embedding(qkv, cos_emb, sin_emb, self.num_heads, self.head_dim)
|
||||
attn_output = self.flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
@@ -181,11 +138,9 @@ class SiglipAttention(nn.Layer):
|
||||
causal=False,
|
||||
**self.flash_attn_kwargs,
|
||||
)[0]
|
||||
# --------
|
||||
|
||||
attn_output = attn_output.reshape((seq_length, -1))
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
@@ -327,11 +282,7 @@ class SiglipMLP(nn.Layer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.hidden_act == "gelu_pytorch_tanh":
|
||||
config.hidden_act = "gelu_new"
|
||||
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.activation_fn = get_activation_fn(config.hidden_act)
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc1.weight.weight_loader = self.weight_loader
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
@@ -353,7 +304,7 @@ class SiglipMLP(nn.Layer):
|
||||
|
||||
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states[0])
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -375,7 +326,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
|
||||
output_attentions=False,
|
||||
cu_seqlens=None,
|
||||
max_seqlen=None,
|
||||
rope_emb=None,
|
||||
cos_emb=None,
|
||||
sin_emb=None,
|
||||
):
|
||||
|
||||
residual = hidden_states
|
||||
@@ -388,7 +340,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
|
||||
output_attentions=output_attentions,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
rope_emb=rope_emb,
|
||||
cos_emb=cos_emb,
|
||||
sin_emb=sin_emb,
|
||||
)
|
||||
|
||||
hs_post_attn = residual + x
|
||||
@@ -545,13 +498,13 @@ class SiglipEncoder(nn.Layer):
|
||||
|
||||
rope_emb = rope_emb_max_grid[pids].flatten(1)
|
||||
rope_emb = rope_emb.tile((1, 2))
|
||||
cos = rope_emb.cos().astype("float32")
|
||||
sin = rope_emb.sin().astype("float32")
|
||||
cos = cos.unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2)
|
||||
rope_emb = (cos, sin)
|
||||
cos_emb = rope_emb.cos().astype("float32")
|
||||
sin_emb = rope_emb.sin().astype("float32")
|
||||
cos_emb = cos_emb.unsqueeze(-2)
|
||||
sin_emb = sin_emb.unsqueeze(-2)
|
||||
else:
|
||||
rope_emb = None
|
||||
cos_emb = None
|
||||
sin_emb = None
|
||||
|
||||
window_indices, cu_seqlens_within_windows = None, None
|
||||
|
||||
@@ -588,7 +541,8 @@ class SiglipEncoder(nn.Layer):
|
||||
output_attentions=output_attentions,
|
||||
cu_seqlens=attn_cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
rope_emb=rope_emb,
|
||||
cos_emb=cos_emb,
|
||||
sin_emb=sin_emb,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
|
||||
74
fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py
Normal file
74
fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import paddle
|
||||
from paddleformers.transformers.activations import ACT2FN
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding, gelu_tanh
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
Dh = x.shape[-1]
|
||||
x1 = x[..., : Dh // 2]
|
||||
x2 = x[..., Dh // 2 :]
|
||||
return paddle.concat([-x2, x1], axis=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(x, cos, sin):
|
||||
orig_dtype = x.dtype
|
||||
x = x.astype("float32")
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed.astype(orig_dtype)
|
||||
|
||||
|
||||
def native_neox_rope_embedding(qkv, cos, sin, num_heads):
|
||||
B, seq_length, D = qkv.shape
|
||||
qkv = qkv.reshape(
|
||||
[
|
||||
seq_length,
|
||||
3,
|
||||
num_heads,
|
||||
-1,
|
||||
]
|
||||
).transpose(perm=[1, 0, 2, 3])
|
||||
q, k, v = qkv.unbind(axis=0)
|
||||
q = apply_rotary_pos_emb_vision(q, cos, sin)
|
||||
k = apply_rotary_pos_emb_vision(k, cos, sin)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def neox_rope_embedding(
|
||||
qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
|
||||
) -> List[paddle.Tensor]:
|
||||
if current_platform.is_cuda():
|
||||
return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
|
||||
else:
|
||||
return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
|
||||
|
||||
|
||||
def get_activation_fn(hidden_act: str):
|
||||
if hidden_act == "gelu_pytorch_tanh":
|
||||
if current_platform.is_cuda():
|
||||
return gelu_tanh
|
||||
else:
|
||||
return ACT2FN["gelu_new"]
|
||||
else:
|
||||
return ACT2FN[hidden_act]
|
||||
88
tests/operators/test_fused_neox_rope_embedding.py
Normal file
88
tests/operators/test_fused_neox_rope_embedding.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
Dh = x.shape[-1]
|
||||
x1 = x[..., : Dh // 2]
|
||||
x2 = x[..., Dh // 2 :]
|
||||
return paddle.concat([-x2, x1], axis=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(x, cos, sin):
|
||||
orig_dtype = x.dtype
|
||||
x = x.astype("float32")
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed.astype(orig_dtype)
|
||||
|
||||
|
||||
class TestFusedNeoxRopeEmbedding(unittest.TestCase):
|
||||
def setUp(self):
|
||||
paddle.set_device("gpu")
|
||||
np.random.seed(42)
|
||||
|
||||
def native_neox_rope_embedding(self, qkv, cos, sin, num_heads):
|
||||
seq_length = qkv.shape[0]
|
||||
qkv = qkv.reshape(
|
||||
[
|
||||
seq_length,
|
||||
3,
|
||||
num_heads,
|
||||
-1,
|
||||
]
|
||||
).transpose(perm=[1, 0, 2, 3])
|
||||
q, k, v = qkv.unbind(axis=0)
|
||||
q = apply_rotary_pos_emb_vision(q, cos, sin)
|
||||
k = apply_rotary_pos_emb_vision(k, cos, sin)
|
||||
return q, k, v
|
||||
|
||||
def test_fused_neox_rope_embedding(self):
|
||||
token_num = 1024
|
||||
hidden_size = 2048
|
||||
head_dim = 128
|
||||
num_heads = hidden_size // head_dim
|
||||
qkv = paddle.randn([token_num, 3 * hidden_size]).astype("bfloat16")
|
||||
cos_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
|
||||
sin_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
|
||||
q, k, v = fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
|
||||
q_base, k_base, v_base = self.native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
|
||||
np.testing.assert_allclose(
|
||||
q.cast("float32").numpy(),
|
||||
q_base.cast("float32").numpy(),
|
||||
rtol=1e-02,
|
||||
atol=1e-02,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
k.cast("float32").numpy(),
|
||||
k_base.cast("float32").numpy(),
|
||||
rtol=1e-02,
|
||||
atol=1e-02,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
v.cast("float32").numpy(),
|
||||
v_base.cast("float32").numpy(),
|
||||
rtol=1e-02,
|
||||
atol=1e-02,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
42
tests/operators/test_gelu_tanh.py
Normal file
42
tests/operators/test_gelu_tanh.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddleformers.transformers.activations import ACT2FN
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import gelu_tanh
|
||||
|
||||
|
||||
class TestGeluTanh(unittest.TestCase):
|
||||
def setUp(self):
|
||||
paddle.set_device("gpu")
|
||||
np.random.seed(42)
|
||||
|
||||
def test_gelu_tanh(self):
|
||||
x = paddle.randn(2048, 4096)
|
||||
y0 = ACT2FN["gelu_new"](x)
|
||||
y1 = gelu_tanh(x)
|
||||
np.testing.assert_allclose(
|
||||
y0.cast("float32").numpy(),
|
||||
y1.cast("float32").numpy(),
|
||||
rtol=1e-04,
|
||||
atol=1e-04,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user