[Feature] Optim PaddleOCR-VL (#4872)

Author: ming1753 · Committed via GitHub on 2025-11-07 17:55:02 +08:00
Parent 329e999f2d · Commit a7ef998e04
13 changed files with 540 additions and 113 deletions


@@ -1046,6 +1046,15 @@ void SpeculateGetTargetLogits(const paddle::Tensor& target_logits,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& accept_num);
+std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
+    const paddle::Tensor& qkv,
+    const paddle::Tensor& cos_emb,
+    const paddle::Tensor& sin_emb,
+    const int num_heads,
+    const int head_dim);
+std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num",
&GetExpertTokenNum,
@@ -1631,4 +1640,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_get_target_logits",
&SpeculateGetTargetLogits,
"speculate_get_target_logits function");
m.def("fused_neox_rope_embedding",
&FusedNeoxRopeEmbedding,
"fused_neox_rope_embedding function");
m.def("gelu_tanh", &GeluTanh, "gelu_tanh function");
}
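Both ops become callable from Python once the extension is built. A minimal dygraph sketch, with illustrative shapes borrowed from the unit tests later in this diff (assumes a CUDA build where these ops were compiled in):

```python
import paddle
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding, gelu_tanh

token_num, num_heads, head_dim = 16, 8, 128
hidden = num_heads * head_dim
# Fused QKV activations; the rope kernel is instantiated for bfloat16.
qkv = paddle.randn([token_num, 3 * hidden]).astype("bfloat16")
# NeoX layout: the first half of each rotary table is duplicated into the second half.
cos = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
sin = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
q, k, v = fused_neox_rope_embedding(qkv, cos, sin, num_heads, head_dim)  # each [token_num, num_heads, head_dim]
y = gelu_tanh(paddle.randn([4, hidden]).astype("bfloat16"))  # elementwise tanh-GELU
```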


@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
template <typename T, int VecSize = 1>
__global__ void FusedNeoxRopeEmbeddingKernel(const T *__restrict__ qkv,
const float *__restrict__ cos_emb,
const float *__restrict__ sin_emb,
T *__restrict__ q,
T *__restrict__ k,
T *__restrict__ v,
const int64_t elem_cnt,
const int num_head,
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_lastdim = last_dim / 2;
const int hidden_size = num_head * half_lastdim;
const int full_hidden_size = num_head * last_dim;
const int offset = 3 * hidden_size;
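// Each logical element covers a (left, right) pair inside one head, so a
// token spans offset = 3 * num_head * (last_dim / 2) logical elements.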
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
const int qkv_bias = bias % hidden_size;
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;
const int base_idx_left = token_idx * 3 * full_hidden_size +
qkv_id * full_hidden_size + hi * last_dim +
h_bias;
const int base_idx_right = base_idx_left + half_lastdim;
const int emb_idx = token_idx * last_dim + h_bias;
const int base_split_idx_left =
token_idx * full_hidden_size + hi * last_dim + h_bias;
const int base_split_idx_right = base_split_idx_left + half_lastdim;
// q,k,v output
T *out_p = nullptr;
if (qkv_id == 0) {
out_p = q;
} else if (qkv_id == 1) {
out_p = k;
} else {
out_p = v;
}
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
// do rope
if (qkv_id < 2) {
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
}
Store<T, VecSize>(left_vec, &out_p[base_split_idx_left]);
Store<T, VecSize>(right_vec, &out_p[base_split_idx_right]);
}
}
std::vector<paddle::Tensor> FusedNeoxRopeEmbedding(
const paddle::Tensor &qkv,
const paddle::Tensor &cos_emb,
const paddle::Tensor &sin_emb,
const int num_heads,
const int head_dim) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
const auto &qkv_dims = qkv.dims();
const int token_num = qkv_dims.size() == 2 ? qkv_dims[0] : qkv_dims[1];
auto stream = qkv.stream();
paddle::Tensor q = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
paddle::Tensor k = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
paddle::Tensor v = GetEmptyTensor(
{token_num, num_heads, head_dim}, qkv.dtype(), qkv.place());
int64_t elem_nums = token_num * num_heads * head_dim * 3 / 2;
constexpr int PackSize = 4;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
FusedNeoxRopeEmbeddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const DataType_ *>(qkv.data<data_t>()),
cos_emb.data<float>(),
sin_emb.data<float>(),
reinterpret_cast<DataType_ *>(q.data<data_t>()),
reinterpret_cast<DataType_ *>(k.data<data_t>()),
reinterpret_cast<DataType_ *>(v.data<data_t>()),
elem_nums,
num_heads,
head_dim);
return {q, k, v};
}
PD_BUILD_STATIC_OP(fused_neox_rope_embedding)
.Inputs({"qkv", "cos_emb", "sin_emb"})
.Outputs({"q", "k", "v"})
.Attrs({"num_heads: int", "head_dim: int"})
.SetKernelFn(PD_KERNEL(FusedNeoxRopeEmbedding));
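For reference, each logical element the kernel processes covers a (left, right) pair inside one head, which is why `elem_nums` is `token_num * num_heads * head_dim * 3 / 2`. A NumPy sketch of the NeoX rotation the kernel applies — a hypothetical helper, matching the pure-Paddle `apply_rotary_pos_emb_vision` reference that appears later in this diff:

```python
import numpy as np

def neox_rotate_ref(x, cos, sin):
    # x: [tokens, heads, head_dim]; cos/sin: [tokens, 1, head_dim] with the
    # first half duplicated into the second half (the tile((1, 2)) layout).
    half = x.shape[-1] // 2
    left, right = x[..., :half], x[..., half:]
    out = np.empty_like(x)
    out[..., :half] = left * cos[..., :half] - right * sin[..., :half]
    out[..., half:] = right * cos[..., half:] + left * sin[..., half:]
    return out
```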


@@ -0,0 +1,106 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
__forceinline__ __device__ float tanh_ptx(float x) {
float y;
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
}
__device__ __forceinline__ float gelu_tanh_func(const float& val) {
const float cdf =
0.5f * (1.0f + tanh_ptx((0.7978845608028654f *
(val + 0.044715f * val * val * val))));
return val * cdf;
}
template <typename T>
__global__ void gelu_tanh_kernel(T* __restrict__ out,
const T* __restrict__ input,
const int d) {
constexpr uint32_t kVecSize = 16 / sizeof(T);
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
const int64_t offset = token_idx * d;
using vec_t = AlignedVector<T, kVecSize>;
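// On CUDA 12 with SM90+, the griddepcontrol fences below enable programmatic
// dependent launch, letting this kernel overlap with its predecessor.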
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
#pragma unroll 1
for (uint32_t idx = thread_idx; idx < d / kVecSize; idx += stride) {
vec_t x_vec;
Load(input + offset + idx * kVecSize, &x_vec);
#pragma unroll
for (uint32_t i = 0; i < kVecSize; ++i) {
x_vec[i] = static_cast<T>(gelu_tanh_func(static_cast<float>(x_vec[i])));
}
Store(x_vec, out + token_idx * d + idx * kVecSize);
}
const int64_t remaining_offset = d - d % (stride * kVecSize);
// process the remaining elements
#pragma unroll 1
for (int64_t idx = thread_idx; idx < d % (stride * kVecSize); idx += stride) {
float x = static_cast<float>(input[offset + remaining_offset + idx]);
out[token_idx * d + remaining_offset + idx] =
static_cast<T>(gelu_tanh_func(x));
}
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
int d = input.dims()[1];
int64_t num_tokens = input.dims()[0];
cudaStream_t stream = input.stream();
paddle::Tensor output =
GetEmptyTensor(input.dims(), input.dtype(), input.place());
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
uint32_t vec_size = 16 / sizeof(scalar_t);
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
gelu_tanh_kernel<scalar_t>,
output.data<scalar_t>(),
input.data<scalar_t>(),
d);
});
return {output};
}
PD_BUILD_STATIC_OP(gelu_tanh)
.Inputs({"input"})
.Outputs({"output"})
.SetKernelFn(PD_KERNEL(GeluTanh));
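The device function implements the standard tanh approximation of GELU, where `0.7978845608028654` is `sqrt(2/pi)`. A NumPy reference of the same formula, useful for sanity checking (hypothetical helper name):

```python
import numpy as np

def gelu_tanh_ref(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(0.7978845608028654 * (x + 0.044715 * np.power(x, 3))))
```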


@@ -305,6 +305,8 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/merge_prefill_decode_output.cu",
"gpu_ops/limit_thinking_content_length_v1.cu",
"gpu_ops/limit_thinking_content_length_v2.cu",
"gpu_ops/fused_neox_rope_embedding.cu",
"gpu_ops/gelu_tanh.cu",
]
# pd_disaggregation


@@ -5,8 +5,8 @@
## 1. Environment Preparation
### 1.1 Support Status
Recommended Hardware Configuration:
-- GPU Memory: 12GB or more
-- Shared Memory: 2GB or more
+- GPU Memory: 8GB or more
+- Shared Memory: 4GB or more
### 1.2 Install FastDeploy
@@ -18,38 +18,38 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta
```shell
python -m fastdeploy.entrypoints.openai.api_server \
    --model PaddlePaddle/PaddleOCR-VL \
-   --port 8180 \
-   --metrics-port 8181 \
-   --engine-worker-queue-port 8182 \
+   --port 8185 \
+   --metrics-port 8186 \
+   --engine-worker-queue-port 8187 \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
-   --gpu-memory-utilization 0.9 \
-   --max-num-seqs 128
+   --gpu-memory-utilization 0.8 \
+   --max-num-seqs 256
```
**Example 2:** Deploying a 16K Context Service on a Single RTX 4090 GPU
```shell
python -m fastdeploy.entrypoints.openai.api_server \
    --model PaddlePaddle/PaddleOCR-VL \
-   --port 8180 \
-   --metrics-port 8181 \
-   --engine-worker-queue-port 8182 \
+   --port 8185 \
+   --metrics-port 8186 \
+   --engine-worker-queue-port 8187 \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
-   --gpu-memory-utilization 0.8 \
-   --max-num-seqs 196
+   --gpu-memory-utilization 0.7 \
+   --max-num-seqs 256
```
**Example 3:** Deploying a 16K Context Service on a Single A100 GPU
```shell
python -m fastdeploy.entrypoints.openai.api_server \
    --model PaddlePaddle/PaddleOCR-VL \
-   --port 8180 \
-   --metrics-port 8181 \
-   --engine-worker-queue-port 8182 \
+   --port 8185 \
+   --metrics-port 8186 \
+   --engine-worker-queue-port 8187 \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
-   --gpu-memory-utilization 0.8 \
+   --gpu-memory-utilization 0.7 \
    --max-num-seqs 256
```
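Once any of the services above is up, a quick end-to-end check can reuse the OpenAI-compatible client call shown later in these docs. A sketch against Example 1's port (the `base_url` and `api_key` values here are placeholders, not mandated by the service):

```python
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8185/v1", api_key="null")
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
            {"type": "text", "text": "OCR:"},
        ]},
    ],
)
print(response.choices[0].message.content)
```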
@@ -71,7 +71,7 @@ An example is a set of configurations that can run stably while also delivering
> **Available GPU memory ratio during initialization**
- **Parameters** `--gpu-memory-utilization`
- **Description** Controls the available GPU memory for FastDeploy service initialization. The default value is 0.9, meaning 10% of the memory is reserved for backup.
-- **Recommendation** It is recommended to use 0.8. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.
+- **Recommendation** It is recommended to use 0.7. If an "out of memory" error occurs during stress testing, you may attempt to reduce this value.
#### 2.2.2 Chunked Prefill
- **Parameters** `--max-num-batched-tokens`


@@ -5,8 +5,8 @@
## 1. Environment Preparation
### 1.1 Support Status
Recommended hardware configuration:
-- GPU Memory: 12GB or more
-- Shared Memory: 2GB or more
+- GPU Memory: 8GB or more
+- Shared Memory: 4GB or more
### 1.2 Install FastDeploy
@@ -18,12 +18,12 @@
```shell
python -m fastdeploy.entrypoints.openai.api_server \
    --model PaddlePaddle/PaddleOCR-VL \
-   --port 8180 \
-   --metrics-port 8181 \
-   --engine-worker-queue-port 8182 \
+   --port 8185 \
+   --metrics-port 8186 \
+   --engine-worker-queue-port 8187 \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
-   --gpu-memory-utilization 0.9 \
+   --gpu-memory-utilization 0.8 \
    --max-num-seqs 128
```
@@ -31,25 +31,25 @@ python -m fastdeploy.entrypoints.openai.api_server \
```shell
python -m fastdeploy.entrypoints.openai.api_server \
    --model PaddlePaddle/PaddleOCR-VL \
-   --port 8180 \
-   --metrics-port 8181 \
-   --engine-worker-queue-port 8182 \
+   --port 8185 \
+   --metrics-port 8186 \
+   --engine-worker-queue-port 8187 \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
-   --gpu-memory-utilization 0.8 \
-   --max-num-seqs 196
+   --gpu-memory-utilization 0.7 \
+   --max-num-seqs 256
```
**Example 3:** Deploying a 16K-context service on a single A100 GPU
```shell
python -m fastdeploy.entrypoints.openai.api_server \
    --model PaddlePaddle/PaddleOCR-VL \
-   --port 8180 \
-   --metrics-port 8181 \
-   --engine-worker-queue-port 8182 \
+   --port 8185 \
+   --metrics-port 8186 \
+   --engine-worker-queue-port 8187 \
    --max-model-len 16384 \
    --max-num-batched-tokens 16384 \
-   --gpu-memory-utilization 0.8 \
+   --gpu-memory-utilization 0.7 \
    --max-num-seqs 256
```
@@ -72,7 +72,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
> **Available GPU memory ratio at initialization**
- **Parameter:** `--gpu-memory-utilization`
- **Purpose:** Controls the GPU memory available when the FastDeploy service initializes. The default is 0.9, i.e. 10% of memory is held in reserve.
-- **Recommendation:** 0.8 is recommended. If the service reports insufficient GPU memory under stress testing, try lowering this value.
+- **Recommendation:** 0.7 is recommended. If the service reports insufficient GPU memory under stress testing, try lowering this value.
#### 2.2.2 Chunked Prefill
- **Parameter:** `--max-num-batched-tokens`


@@ -197,7 +197,7 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-d '{
"messages": [
{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}},
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
{"type": "text", "text": "OCR:"}
]}
],
@@ -216,7 +216,7 @@ response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg"}},
{"type": "image_url", "image_url": {"url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"}},
{"type": "text", "text": "OCR:"}
]
},


@@ -22,6 +22,8 @@ from fastdeploy.utils import data_processor_logger
from .process import DataProcessor
+_SAMPLING_EPS = 1e-5
class PaddleOCRVLProcessor(TextProcessor):
"""
@@ -61,7 +63,6 @@ class PaddleOCRVLProcessor(TextProcessor):
tool_parser_obj: Tool parser instance
"""
super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj)
-        data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs)
self.processor = DataProcessor(
@@ -252,6 +253,9 @@ class PaddleOCRVLProcessor(TextProcessor):
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token
if request.get("top_p") is not None and request.get("top_p") < _SAMPLING_EPS:
request["top_p"] = _SAMPLING_EPS
return request
def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
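The new `top_p` clamp keeps near-zero nucleus values numerically safe: anything below `_SAMPLING_EPS` is raised to the epsilon instead of reaching the sampler as an effectively empty distribution. A minimal illustration of the behavior:

```python
_SAMPLING_EPS = 1e-5

request = {"top_p": 0.0}
if request.get("top_p") is not None and request["top_p"] < _SAMPLING_EPS:
    request["top_p"] = _SAMPLING_EPS
assert request["top_p"] == _SAMPLING_EPS
```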


@@ -25,7 +25,6 @@ from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.utils.log import logger
-from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
@@ -154,12 +153,8 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
)
# Persistent buffers for CUDA graphs.
-        if envs.FD_ENABLE_MAX_PREFILL:
-            max_length = fd_config.scheduler_config.max_num_seqs * fd_config.model_config.max_model_len
-        else:
-            max_length = fd_config.model_config.max_model_len
-        self._input_embeddings = paddle.zeros(
-            [max_length, fd_config.model_config.hidden_size],
+        self._decoder_input_embeddings = paddle.zeros(
+            [fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
@@ -265,12 +260,19 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features
)
-        self._input_embeddings.copy_(input_embeddings, False)
-        hidden_states = self.model(
-            input_embeddings=self._input_embeddings,
-            forward_meta=forward_meta,
-        )
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            hidden_states = self.model(
+                input_embeddings=self._decoder_input_embeddings,
+                forward_meta=forward_meta,
+            )
+        else:
+            hidden_states = self.model(
+                input_embeddings=input_embeddings,
+                forward_meta=forward_meta,
+            )
return hidden_states
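The split matters because a captured CUDA graph replays fixed device addresses: decode steps must copy their embeddings into a persistent buffer sized `[max_num_seqs, hidden_size]` before entering the graphed model, while prefill (never captured) can consume the freshly computed tensor directly. A hedged sketch of the pattern, with illustrative names rather than the FastDeploy API:

```python
import paddle

class PersistentDecodeInputs:
    """Stable-address staging buffer for CUDA-graph replay (illustrative)."""

    def __init__(self, max_num_seqs: int, hidden_size: int, dtype: str = "bfloat16"):
        # Allocated once; captured graphs keep pointing at this allocation.
        self.buf = paddle.zeros([max_num_seqs, hidden_size], dtype=dtype)

    def stage(self, embeddings: paddle.Tensor) -> paddle.Tensor:
        # Non-blocking in-place copy into the stable address, mirroring the
        # copy_(..., False) call used above.
        self.buf.copy_(embeddings, False)
        return self.buf
```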


@@ -21,39 +21,13 @@ import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
-from paddleformers.transformers.activations import ACT2FN
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from .config import PaddleOCRVisionConfig
-def rotate_half(x):
-    Dh = x.shape[-1]
-    x1 = x[..., : Dh // 2]
-    x2 = x[..., Dh // 2 :]
-    return paddle.concat([-x2, x1], axis=-1)
-def _ensure_cos_sin_dim(cos, sin, dim_needed):
-    last = cos.shape[-1]
-    if last == dim_needed:
-        return cos, sin
-    elif last * 2 == dim_needed:
-        cos = paddle.concat([cos, cos], axis=-1)
-        sin = paddle.concat([sin, sin], axis=-1)
-        return cos, sin
-    else:
-        raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}")
-def apply_rotary_pos_emb_vision(x, cos, sin):
-    orig_dtype = x.dtype
-    x = x.astype("float32")
-    x_embed = (x * cos) + (rotate_half(x) * sin)
-    return x_embed.astype(orig_dtype)
+from .siglip_ops import get_activation_fn, neox_rope_embedding
class SiglipAttention(nn.Layer):
@@ -147,29 +121,12 @@ class SiglipAttention(nn.Layer):
output_attentions: Optional[bool] = False,
cu_seqlens: Optional[List[paddle.Tensor]] = None,
max_seqlen: Optional[paddle.Tensor] = None,
-        rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,  # (cos, sin)
+        cos_emb: Optional[paddle.Tensor] = None,
+        sin_emb: Optional[paddle.Tensor] = None,
):
B, seq_length, D = hidden_states.shape
-        qkv = (
-            self.qkv_proj(hidden_states)
-            .reshape(
-                [
-                    seq_length,
-                    3,
-                    self.num_heads,
-                    -1,
-                ]
-            )
-            .transpose(perm=[1, 0, 2, 3])
-        )
-        q, k, v = qkv.unbind(axis=0)
-        cos, sin = rope_emb
-        # --------
-        q = apply_rotary_pos_emb_vision(q, cos, sin)
-        k = apply_rotary_pos_emb_vision(k, cos, sin)
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = neox_rope_embedding(qkv, cos_emb, sin_emb, self.num_heads, self.head_dim)
attn_output = self.flash_attn_func(
q,
k,
@@ -181,11 +138,9 @@ class SiglipAttention(nn.Layer):
causal=False,
**self.flash_attn_kwargs,
)[0]
-        # --------
attn_output = attn_output.reshape((seq_length, -1))
attn_output = self.out_proj(attn_output)
return attn_output
@@ -327,11 +282,7 @@ class SiglipMLP(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
-        if config.hidden_act == "gelu_pytorch_tanh":
-            config.hidden_act = "gelu_new"
-        self.activation_fn = ACT2FN[config.hidden_act]
+        self.activation_fn = get_activation_fn(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc1.weight.weight_loader = self.weight_loader
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -353,7 +304,7 @@ class SiglipMLP(nn.Layer):
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
hidden_states = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.activation_fn(hidden_states[0])
hidden_states = self.fc2(hidden_states)
return hidden_states
@@ -375,7 +326,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
output_attentions=False,
cu_seqlens=None,
max_seqlen=None,
-        rope_emb=None,
+        cos_emb=None,
+        sin_emb=None,
):
residual = hidden_states
@@ -388,7 +340,8 @@ class SiglipEncoderLayer(paddle.nn.Layer):
output_attentions=output_attentions,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
-            rope_emb=rope_emb,
+            cos_emb=cos_emb,
+            sin_emb=sin_emb,
)
hs_post_attn = residual + x
@@ -545,13 +498,13 @@ class SiglipEncoder(nn.Layer):
rope_emb = rope_emb_max_grid[pids].flatten(1)
rope_emb = rope_emb.tile((1, 2))
-            cos = rope_emb.cos().astype("float32")
-            sin = rope_emb.sin().astype("float32")
-            cos = cos.unsqueeze(-2)
-            sin = sin.unsqueeze(-2)
-            rope_emb = (cos, sin)
+            cos_emb = rope_emb.cos().astype("float32")
+            sin_emb = rope_emb.sin().astype("float32")
+            cos_emb = cos_emb.unsqueeze(-2)
+            sin_emb = sin_emb.unsqueeze(-2)
else:
-            rope_emb = None
+            cos_emb = None
+            sin_emb = None
window_indices, cu_seqlens_within_windows = None, None
@@ -588,7 +541,8 @@ class SiglipEncoder(nn.Layer):
output_attentions=output_attentions,
cu_seqlens=attn_cu_seqlens,
max_seqlen=max_seqlen,
-                rope_emb=rope_emb,
+                cos_emb=cos_emb,
+                sin_emb=sin_emb,
)
hidden_states = layer_outputs[0]


@@ -0,0 +1,74 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import List
import paddle
from paddleformers.transformers.activations import ACT2FN
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding, gelu_tanh
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
def native_neox_rope_embedding(qkv, cos, sin, num_heads):
B, seq_length, D = qkv.shape
qkv = qkv.reshape(
[
seq_length,
3,
num_heads,
-1,
]
).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
return q, k, v
def neox_rope_embedding(
qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
) -> List[paddle.Tensor]:
if current_platform.is_cuda():
return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
else:
return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
def get_activation_fn(hidden_act: str):
if hidden_act == "gelu_pytorch_tanh":
if current_platform.is_cuda():
return gelu_tanh
else:
return ACT2FN["gelu_new"]
else:
return ACT2FN[hidden_act]
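A hedged usage sketch: on CUDA builds the fused kernels run, elsewhere the pure-Paddle reference path does. Shapes follow the unit test below, with a batch dimension of 1 so both paths accept the same layout:

```python
import paddle

act = get_activation_fn("gelu_pytorch_tanh")  # fused gelu_tanh on CUDA, ACT2FN["gelu_new"] otherwise

num_heads, head_dim = 8, 128
qkv = paddle.randn([1, 64, 3 * num_heads * head_dim]).astype("bfloat16")
cos = paddle.rand([64, head_dim // 2]).tile((1, 2)).unsqueeze(1)
sin = paddle.rand([64, head_dim // 2]).tile((1, 2)).unsqueeze(1)
q, k, v = neox_rope_embedding(qkv, cos, sin, num_heads, head_dim)  # each [64, num_heads, head_dim]
```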


@@ -0,0 +1,88 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import fused_neox_rope_embedding
def rotate_half(x):
Dh = x.shape[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
def apply_rotary_pos_emb_vision(x, cos, sin):
orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)
return x_embed.astype(orig_dtype)
class TestFusedNeoxRopeEmbedding(unittest.TestCase):
def setUp(self):
paddle.set_device("gpu")
np.random.seed(42)
def native_neox_rope_embedding(self, qkv, cos, sin, num_heads):
seq_length = qkv.shape[0]
qkv = qkv.reshape(
[
seq_length,
3,
num_heads,
-1,
]
).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
return q, k, v
def test_fused_neox_rope_embedding(self):
token_num = 1024
hidden_size = 2048
head_dim = 128
num_heads = hidden_size // head_dim
qkv = paddle.randn([token_num, 3 * hidden_size]).astype("bfloat16")
cos_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
sin_emb = paddle.rand([token_num, head_dim // 2]).tile((1, 2)).unsqueeze(1)
q, k, v = fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
q_base, k_base, v_base = self.native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
np.testing.assert_allclose(
q.cast("float32").numpy(),
q_base.cast("float32").numpy(),
rtol=1e-02,
atol=1e-02,
)
np.testing.assert_allclose(
k.cast("float32").numpy(),
k_base.cast("float32").numpy(),
rtol=1e-02,
atol=1e-02,
)
np.testing.assert_allclose(
v.cast("float32").numpy(),
v_base.cast("float32").numpy(),
rtol=1e-02,
atol=1e-02,
)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,42 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddleformers.transformers.activations import ACT2FN
from fastdeploy.model_executor.ops.gpu import gelu_tanh
class TestGeluTanh(unittest.TestCase):
def setUp(self):
paddle.set_device("gpu")
np.random.seed(42)
def test_gelu_tanh(self):
x = paddle.randn([2048, 4096])
y0 = ACT2FN["gelu_new"](x)
y1 = gelu_tanh(x)
np.testing.assert_allclose(
y0.cast("float32").numpy(),
y1.cast("float32").numpy(),
rtol=1e-04,
atol=1e-04,
)
if __name__ == "__main__":
unittest.main()