Files
FastDeploy/fastdeploy/model_executor/models/paddleocr_vl/siglip.py
ming1753 7681375a19 [BugFix] PaddleOCR-VL fix FD_DEBUG type and support v1 loader (#4605)
* [Bug Fix] PaddleOCRVL fix FD_DEBUG type and support HF model

* fix bug

* fix bug

* fix bug
2025-10-28 09:47:47 +08:00

771 lines
30 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from typing import List, Optional, Tuple, Union
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleformers.transformers.activations import ACT2FN
from paddleformers.transformers.model_utils import PretrainedModel
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from .config import PaddleOCRVisionConfig
def rotate_half(x):
    """Rotate the last axis of ``x`` by half, as used by rotary embeddings.

    The last axis is split into two equal halves ``(x1, x2)`` and the result
    is ``concat([-x2, x1])`` along that same axis.
    """
    half = x.shape[-1] // 2
    first_half, second_half = x[..., :half], x[..., half:]
    return paddle.concat([-second_half, first_half], axis=-1)
def _ensure_cos_sin_dim(cos, sin, dim_needed):
last = cos.shape[-1]
if last == dim_needed:
return cos, sin
elif last * 2 == dim_needed:
cos = paddle.concat([cos, cos], axis=-1)
sin = paddle.concat([sin, sin], axis=-1)
return cos, sin
else:
raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}")
def apply_rotary_pos_emb_vision(x, cos, sin):
    """Apply a rotary position embedding to ``x`` with precomputed ``cos``/``sin``.

    The rotation ``x * cos + rotate_half(x) * sin`` is evaluated in float32, and
    the result is cast back to the dtype ``x`` had on entry. ``cos`` and ``sin``
    must broadcast against ``x``.
    """
    input_dtype = x.dtype
    x_fp32 = x.astype("float32")
    rotated = (x_fp32 * cos) + (rotate_half(x_fp32) * sin)
    return rotated.astype(input_dtype)
class SiglipAttention(nn.Layer):
    """Multi-head self-attention for the SigLIP vision encoder.

    Queries, keys and values come from a single fused ``qkv_proj`` linear layer.
    Attention runs through Paddle's variable-length flash-attention kernels:
    FA3 (``flash_attention_v3_varlen``) when ``FLAGS_flash_attn_version=3`` and
    both the current GPU and the Paddle build support SM90+, otherwise
    ``flash_attn_unpadded``. Packed sequences are delimited by ``cu_seqlens``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        assert self.head_dim * self.num_heads == self.embed_dim
        self.scale = self.head_dim**-0.5

        # Fused q/k/v projection. Checkpoints may provide q, k and v separately;
        # qkv_weight_loader writes each shard into its slice of the output axis.
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias_attr=True)
        self.qkv_proj.weight.weight_loader = self.qkv_weight_loader
        self.qkv_proj.bias.weight_loader = self.qkv_weight_loader

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj.weight.weight_loader = self.out_proj_weight_loader

        enable_fa3 = False
        flash_attn_version = int(os.environ.get("FLAGS_flash_attn_version", "2"))
        if flash_attn_version == 3:
            prop = paddle.device.cuda.get_device_properties()
            cc = prop.major * 10 + prop.minor
            is_current_sm_supported = cc >= 90
            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
            enable_fa3 = is_current_sm_supported and is_paddle_supported
        if enable_fa3:
            from paddle.nn.functional.flash_attention import flash_attention_v3_varlen

            self.flash_attn_func = flash_attention_v3_varlen
            self.flash_attn_kwargs = {}
        else:
            from paddle.nn.functional.flash_attention import flash_attn_unpadded

            self.flash_attn_func = flash_attn_unpadded
            self.flash_attn_kwargs = {"scale": self.scale, "training": False}

    @staticmethod
    def _copy_into_param(param, loaded_weight):
        """Shape-check ``loaded_weight`` against ``param``, reconcile dtypes, and copy it in.

        int8 data destined for a float8_e4m3fn parameter is reinterpreted with
        ``view`` (a bit-level reinterpretation) instead of a numeric cast; any
        other dtype mismatch is resolved with ``cast``.
        """
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        if loaded_weight.dtype != param.dtype:
            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)
        param.copy_(loaded_weight, False)

    def qkv_weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """Load one q/k/v shard, or a whole fused tensor, into ``qkv_proj``.

        Args:
            param: ``qkv_proj.weight`` or ``qkv_proj.bias``.
            loaded_weight: checkpoint tensor. 2-D tensors are transposed before
                copying (presumably ``[out, in]`` checkpoint layout to Paddle's
                ``[in, out]`` -- TODO confirm against the calling loader).
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` to fill the matching
                third of the last axis, or ``None`` to fill the whole parameter.

        Raises:
            ValueError: if ``loaded_shard_id`` is anything else.
        """
        loaded_weight = get_tensor(loaded_weight)
        if loaded_weight.dim() == 2:
            loaded_weight = loaded_weight.transpose([1, 0])
        if not param._is_initialized():
            param.initialize()
        if loaded_shard_id is not None:
            # q, k and v are laid out back to back along the last axis.
            shard_index = {"q": 0, "k": 1, "v": 2}.get(loaded_shard_id)
            if shard_index is None:
                raise ValueError(f"Unexpected qkv shard id: {loaded_shard_id!r}, expected 'q', 'k', 'v' or None")
            shard_size = self.num_heads * self.head_dim
            shard_offset = shard_index * shard_size
            param = slice_fn(param, -1, start=shard_offset, end=shard_offset + shard_size)
        self._copy_into_param(param, loaded_weight)

    def out_proj_weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """Load ``out_proj.weight``; the 2-D checkpoint tensor is transposed first."""
        loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
        self._copy_into_param(param, loaded_weight)

    def forward(
        self,
        hidden_states: paddle.Tensor,  # [B, L, D]
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[List[paddle.Tensor]] = None,
        max_seqlen: Optional[paddle.Tensor] = None,
        rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,  # (cos, sin)
    ):
        """Run packed variable-length self-attention.

        Args:
            hidden_states: ``[B, L, D]``. The reshape folds all tokens into
                ``L`` rows, so only ``B == 1`` (all sequences packed) is handled.
            attention_mask: unused; sequence boundaries come from ``cu_seqlens``.
            output_attentions: unused; attention weights are never returned.
            cu_seqlens: cumulative sequence boundaries, used for both q and k.
            max_seqlen: longest packed sequence length, used for both q and k.
            rope_emb: ``(cos, sin)`` applied to q and k. When ``None``, no
                rotary embedding is applied.

        Returns:
            Tensor of shape ``[L, D]`` (the batch axis is dropped).
        """
        _, seq_length, _ = hidden_states.shape
        qkv = (
            self.qkv_proj(hidden_states)
            .reshape([seq_length, 3, self.num_heads, -1])
            .transpose(perm=[1, 0, 2, 3])
        )
        q, k, v = qkv.unbind(axis=0)
        # SiglipEncoder passes rope_emb=None when use_rope is off.
        if rope_emb is not None:
            cos, sin = rope_emb
            q = apply_rotary_pos_emb_vision(q, cos, sin)
            k = apply_rotary_pos_emb_vision(k, cos, sin)
        attn_output = self.flash_attn_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            causal=False,
            **self.flash_attn_kwargs,
        )[0]
        attn_output = attn_output.reshape((seq_length, -1))
        return self.out_proj(attn_output)
class SiglipVisionEmbeddings(nn.Layer):
    """Patch embedding plus position embedding for the SigLIP vision tower.

    Pixels are patchified by a convolution whose kernel and stride both equal
    ``patch_size``. Position information is added in one of two ways:

    * the learned ``position_embedding`` grid is bilinearly resized to each
      image's patch grid, when ``interpolate_pos_encoding`` is set and
      ``image_grid_thw`` is given;
    * otherwise ``packing_position_embedding`` is indexed by ``position_ids``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size  # e.g. 1152
        self.image_size = config.image_size  # e.g. 384
        self.patch_size = config.patch_size  # e.g. 14
        self.patch_embedding = nn.Conv2D(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="VALID",
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2  # e.g. 729
        self.num_positions = self.num_patches
        # State for fetch_position_embedding_lfu_cache: (h, w) -> embedding / hit count.
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
        self.register_buffer(
            "position_ids",
            paddle.arange(self.num_positions).unsqueeze(0),
            persistable=False,
        )

    def interpolate_pos_encoding(self, embeddings, height: int, width: int, is_after_patchify: bool = False):
        """Bilinearly resize the learned square position-embedding grid.

        Args:
            embeddings: only its last dimension (the embedding width) is read.
            height, width: target size in patches when ``is_after_patchify`` is
                True; otherwise in pixels, divided by ``patch_size`` here.
            is_after_patchify: see above.

        Returns:
            Tensor of shape ``[1, new_height * new_width, dim]``.
        """
        num_positions = self.position_embedding.weight.shape[0]
        dim = embeddings.shape[-1]
        if is_after_patchify:
            new_height, new_width = height, width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size
        # The table holds a square side x side grid; round guards against float
        # error in the square root instead of truncating.
        side = int(round(num_positions**0.5))
        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
        patch_pos_embed = patch_pos_embed.reshape((1, side, side, dim)).transpose((0, 3, 1, 2))
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )
        return patch_pos_embed.transpose((0, 2, 3, 1)).reshape((1, -1, dim))

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of list nesting; non-list items are kept as they are."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache=20):
        """Return the interpolated position embedding for grid ``(h, w)``, memoized.

        At most ``max_cache`` grids are kept. When full, the entry with the
        lowest hit count is evicted before inserting a new one.
        """
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]
        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(self.cache_position_count, key=self.cache_position_count.get)
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)
        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding

    def forward(
        self,
        pixel_values: paddle.Tensor,  # [B, L, C, H, W]
        position_ids: Optional[paddle.Tensor] = None,  # [B or 1, S]
        image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
        interpolate_pos_encoding: bool = False,
    ) -> paddle.Tensor:
        """Embed patches and add position embeddings.

        Args:
            pixel_values: ``[B, L, C, H, W]`` patches; a 4-D input is treated as
                ``B == 1``.
            position_ids: required (asserted); indexes
                ``packing_position_embedding`` on the non-interpolating path.
            image_grid_thw: ``(t, h, w)`` per image, possibly nested in lists.
            interpolate_pos_encoding: use the resized learned grid per image.
                This path requires ``B == 1``.

        Returns:
            Tensor of shape ``[B, L, embed_dim]``.

        Raises:
            NotImplementedError: if ``pixel_values`` is neither 4-D nor 5-D.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() != 5:
            raise NotImplementedError(str(pixel_values.shape))
        assert position_ids is not None
        from einops import rearrange

        batch_size, squence_len, channel, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(-2).squeeze(-1)
        embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len)

        if not (interpolate_pos_encoding and image_grid_thw is not None):
            return embeddings + self.packing_position_embedding(position_ids)

        flatten_image_grid_thw = np.array(self.flatten_list(image_grid_thw))
        assert batch_size == 1
        assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], (
            flatten_image_grid_thw,
            embeddings.shape,
        )
        embeddings = embeddings.squeeze(0)
        tmp_embeddings = list()
        start = 0
        # Iterate the *flattened* grids: image_grid_thw may contain nested lists
        # of (t, h, w), which cannot be unpacked directly.
        for image_grid in flatten_image_grid_thw:
            t, h, w = (int(v) for v in image_grid)
            end = start + t * h * w
            image_embeddings = embeddings[start:end, :]
            position_embedding = (
                self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).tile((t, 1))
            ).astype(image_embeddings.dtype)
            tmp_embeddings.append(image_embeddings + position_embedding)
            start = end
        return paddle.concat(tmp_embeddings, axis=0).unsqueeze(0)
class SiglipMLP(nn.Layer):
    """Two-layer feed-forward block computing ``fc2(act(fc1(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Remap the activation name locally: the config object is shared by every
        # layer, so it must not be mutated here.
        hidden_act = config.hidden_act
        if hidden_act == "gelu_pytorch_tanh":
            hidden_act = "gelu_new"
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc1.weight.weight_loader = self.weight_loader
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.fc2.weight.weight_loader = self.weight_loader

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        """Load an ``fc1``/``fc2`` weight, transposing the 2-D checkpoint tensor first.

        int8 data destined for a float8_e4m3fn parameter is reinterpreted with
        ``view``; any other dtype mismatch is resolved with ``cast``.
        """
        loaded_weight = get_tensor(loaded_weight)
        loaded_weight = loaded_weight.transpose([1, 0])
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        if loaded_weight.dtype != param.dtype:
            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)
        param.copy_(loaded_weight, False)

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
class SiglipEncoderLayer(paddle.nn.Layer):
    """Pre-LayerNorm transformer block: self-attention then MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
        self.self_attn = SiglipAttention(config)
        self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions=False,
        cu_seqlens=None,
        max_seqlen=None,
        rope_emb=None,
    ):
        """Apply one encoder block.

        Returns:
            A 1-tuple holding the updated hidden states; no attention weights
            are produced.
        """
        attn_delta = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            rope_emb=rope_emb,
        )
        after_attn = hidden_states + attn_delta
        after_mlp = after_attn + self.mlp(self.layer_norm2(after_attn))
        return (after_mlp,)
class SigLIPRotaryEmbedding(nn.Layer):
    """Rotary-embedding angle table generator.

    ``forward(seqlen)`` returns the outer product of positions ``0..seqlen-1``
    with inverse frequencies ``theta ** (-k / dim)`` for ``k = 0, 2, 4, ... < dim``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        """Register the non-persistable ``inv_freq`` buffer in the default dtype."""
        exponents = paddle.arange(0, self.dim, 2, dtype="float32") / self.dim
        inv_freq = 1.0 / (self.theta**exponents)
        self.register_buffer("inv_freq", inv_freq.astype(paddle.get_default_dtype()), persistable=False)

    def forward(self, seqlen: int) -> paddle.Tensor:
        positions = paddle.arange(seqlen, dtype=self.inv_freq.dtype)
        return paddle.outer(positions, self.inv_freq)
class SiglipEncoder(nn.Layer):
    """Stack of ``SiglipEncoderLayer`` with optional 2-D rotary embedding and window attention."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.LayerList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Height and width positions each index this table; _build_rope_emb
        # concatenates the two lookups and tiles the result twice.
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
        self.gradient_checkpointing = False

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of list nesting; non-list items are kept as they are."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw

    def build_window_index(self, image_grid, window_size):
        """Compute the token order and boundaries for windowed attention.

        Each frame's ``h x w`` patch grid is padded up to a multiple of
        ``window_size`` and cut into ``window_size x window_size`` windows;
        padded slots are then dropped.

        Args:
            image_grid: iterable of ``(t, h, w)`` triples (ints, array rows or tensors).
            window_size: window side length in patches (must be > 0).

        Returns:
            window_indices: int64 tensor of length ``sum(t*h*w)``; the original
                token index at each position of the window-ordered sequence.
            cu_seqlens_within_windows: int32 prefix sums of per-window token
                counts, with a leading 0.
        """
        from einops import rearrange

        pad_values = -100
        window_indices = list()
        cu_seqlens_within_windows = list()
        start_window_index = 0
        for grid in image_grid:
            # Convert each component; map(int, image_grid) would try to turn a
            # whole (t, h, w) triple into a single int.
            t, h, w = (int(v) for v in grid)
            window_index = paddle.arange(t * h * w).reshape((t, h, w))
            pad_h = (-h) % window_size
            pad_w = (-w) % window_size
            window_index = F.pad(window_index, (0, pad_w, 0, pad_h), value=pad_values)
            window_index = rearrange(
                window_index,
                "t (h p1) (w p2) -> t (h w) (p1 p2)",
                p1=window_size,
                p2=window_size,
            )
            window_seqlens = (window_index != pad_values).astype("int64").sum(-1).reshape([-1])
            window_index = window_index.reshape([-1])
            window_index = window_index[window_index != pad_values]
            window_indices.append(window_index + start_window_index)
            cu_seqlens_within_windows.append(window_seqlens.cumsum(0) + start_window_index)
            start_window_index += t * h * w
        window_indices = paddle.concat(window_indices, axis=0)
        cu_seqlens_within_windows = paddle.concat(cu_seqlens_within_windows, axis=0)
        cu_seqlens_within_windows = F.pad(cu_seqlens_within_windows, (1, 0), value=0).astype("int32")
        return window_indices, cu_seqlens_within_windows

    @staticmethod
    def _grid_position_ids(flatten_image_grid_thw):
        """Return ``(height_ids, width_ids)``: each token's row and column within its frame."""
        split_hids = list()
        split_wids = list()
        for grid in flatten_image_grid_thw:
            t, h, w = map(int, grid)
            image_pids = paddle.arange(t * h * w) % (h * w)
            split_hids.append(image_pids // w)
            split_wids.append(image_pids % w)
        return paddle.concat(split_hids, axis=0), paddle.concat(split_wids, axis=0)

    def _build_rope_emb(self, height_position_ids, width_position_ids):
        """Build float32 ``(cos, sin)`` tables of shape ``[N, 1, rope_dim]`` from 2-D positions."""
        pids = paddle.stack([height_position_ids, width_position_ids], axis=-1).astype(paddle.int64)
        max_grid_size = pids.max() + 1
        rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
        rope_emb = rope_emb_max_grid[pids].flatten(1)
        rope_emb = rope_emb.tile((1, 2))
        cos = rope_emb.cos().astype("float32").unsqueeze(-2)
        sin = rope_emb.sin().astype("float32").unsqueeze(-2)
        return cos, sin

    def forward(
        self,
        inputs_embeds: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[paddle.Tensor] = None,
        image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
        height_position_ids: Optional[paddle.Tensor] = None,
        width_position_ids: Optional[paddle.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: Optional[int] = -1,
        vision_or_text: str = "vision",
    ):
        """Run all encoder layers and return the final hidden states.

        Args:
            inputs_embeds: ``[1, N, D]`` packed token embeddings.
            attention_mask: cast to the input dtype and forwarded; the attention
                layers do not use it.
            output_attentions, output_hidden_states: accepted for interface
                compatibility; intermediate attentions and hidden states are
                not returned.
            cu_seqlens: packed-sequence boundaries (used when window attention is off).
            image_grid_thw: ``(t, h, w)`` per image; required for rope or windows.
            height_position_ids, width_position_ids: optional precomputed 2-D
                positions; derived from ``image_grid_thw`` when either is None.
            use_rope: apply 2-D rotary embeddings (vision only, must be ``True`` exactly).
            window_size: > 0 enables window attention (vision only).
            vision_or_text: ``"vision"`` or ``"text"``.

        Returns:
            Hidden states after the last layer, in the original token order.
        """
        assert vision_or_text in ["vision", "text"]
        use_window_attn = window_size > 0 and vision_or_text == "vision"
        use_rope = (use_rope is True) and (vision_or_text == "vision")
        hidden_states = inputs_embeds
        attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None

        flatten_image_grid_thw = None
        if use_rope or use_window_attn:
            flatten_image_grid_thw = np.array(self.flatten_list(image_grid_thw))
            assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], (
                flatten_image_grid_thw,
                hidden_states.shape,
            )

        window_indices, cu_seqlens_within_windows, reversed_window_indices = None, None, None
        if use_window_attn:
            window_indices, cu_seqlens_within_windows = self.build_window_index(flatten_image_grid_thw, window_size)
            reversed_window_indices = window_indices.argsort()

        rope_emb = None
        if use_rope:
            if width_position_ids is None or height_position_ids is None:
                height_position_ids, width_position_ids = self._grid_position_ids(flatten_image_grid_thw)
            if use_window_attn:
                # Positions must follow tokens into window order.
                height_position_ids = height_position_ids[window_indices]
                width_position_ids = width_position_ids[window_indices]
            rope_emb = self._build_rope_emb(height_position_ids, width_position_ids)

        if use_window_attn:
            attn_cu_seqlens = cu_seqlens_within_windows
            hidden_states = hidden_states[:, window_indices, :]
        else:
            attn_cu_seqlens = cu_seqlens
        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()

        for encoder_layer in self.layers:
            # Layers return a 1-tuple; there are no attention weights to collect.
            hidden_states = encoder_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=attn_cu_seqlens,
                max_seqlen=max_seqlen,
                rope_emb=rope_emb,
            )[0]

        if use_window_attn:
            hidden_states = hidden_states[:, reversed_window_indices, :]
        return hidden_states
class SiglipMultiheadAttentionPoolingHead(nn.Layer):
    """Multihead Attention Pooling.

    A learned probe token attends over the sequence, followed by a residual
    LayerNorm + MLP. Returns one pooled vector per batch element.
    """

    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        self.probe = self.create_parameter(
            shape=(1, 1, config.hidden_size),
            default_initializer=paddle.nn.initializer.Normal(),
        )
        self.attention = nn.MultiHeadAttention(config.hidden_size, config.num_attention_heads)
        self.layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state, key_padding_mask=None):
        """Pool ``hidden_state`` of shape ``[B, L, D]`` into ``[B, D]``.

        ``key_padding_mask`` is accepted but not used.
        """
        batch_size = hidden_state.shape[0]
        probe = self.probe.tile((batch_size, 1, 1))
        attn_out = self.attention(probe, hidden_state, hidden_state)
        # Paddle's MultiHeadAttention returns a bare Tensor unless need_weights
        # or a cache is used; indexing it with [0] would drop the batch axis.
        if isinstance(attn_out, (tuple, list)):
            attn_out = attn_out[0]
        residual = attn_out
        hidden_state = self.layernorm(attn_out)
        hidden_state = residual + self.mlp(hidden_state)
        return hidden_state[:, 0]
class SiglipVisionTransformer(nn.Layer):
    """Embeddings -> encoder -> post-LayerNorm, split back into per-sample outputs."""

    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, epsilon=config.layer_norm_eps)
        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
        if self.use_head:
            # Created so its weights exist for loading; forward() does not call it.
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask=None,
        sample_indices=None,
        image_indices=None,
        position_ids=None,
        height_position_ids=None,
        width_position_ids=None,
        cu_seqlens=None,
        padding_mask=None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
    ):
        """Encode packed images and return one ``[len_i, D]`` tensor per sample.

        ``cu_seqlens`` (required) holds the packed-sequence boundaries used both
        by the encoder and to split the output. ``sample_indices``,
        ``image_indices``, ``padding_mask``, ``vision_return_embed_list`` and
        ``return_pooler_output`` are accepted but unused.
        """
        # Fail fast: cu_seqlens is needed to split the output, so check it
        # before running the expensive embedding/encoder passes.
        assert cu_seqlens is not None
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )
        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )
        last_hidden_state = self.post_layernorm(last_hidden_state)
        # Read the boundaries once as Python ints instead of indexing a tensor per sample.
        boundaries = cu_seqlens.tolist()
        return [
            last_hidden_state[:, start:end, :].squeeze(0) for start, end in zip(boundaries[:-1], boundaries[1:])
        ]
class SiglipVisionModel(PretrainedModel):
    """``PretrainedModel`` wrapper around ``SiglipVisionTransformer``."""

    config_class = PaddleOCRVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: PaddleOCRVisionConfig, prefix=""):
        super().__init__(config)
        self.prefix_name = prefix
        self.vision_model = SiglipVisionTransformer(config)

    def get_input_embeddings(self) -> nn.Layer:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids=None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
        cu_seqlens=None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
    ):
        """Forward all arguments to ``SiglipVisionTransformer`` and return its per-sample list."""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )

    @staticmethod
    def _fuse_qkv(state_dict, qkv_key):
        """Pop the separate q/k/v tensors behind ``qkv_key`` and concatenate them on the last axis."""
        parts = [
            get_tensor(state_dict.pop(qkv_key.replace("qkv_proj", proj_name)))
            for proj_name in ("q_proj", "k_proj", "v_proj")
        ]
        return paddle.concat(parts, axis=-1)

    def load_state_dict(self, state_dict):
        """Copy every parameter from ``state_dict``, consuming the entries it uses.

        Keys are ``"{prefix}.{param_name}"`` (or just ``param_name`` when the
        prefix is empty). A missing fused ``self_attn.qkv_proj`` weight/bias is
        assembled from separate ``q_proj``/``k_proj``/``v_proj`` entries.

        Raises:
            ValueError: if a key is missing or a tensor's shape does not match.
        """
        for param_name, param in dict(self.named_parameters()).items():
            # An empty prefix must not produce keys with a leading ".".
            state_dict_key = f"{self.prefix_name}.{param_name}" if self.prefix_name else param_name
            if state_dict_key in state_dict:
                tensor = get_tensor(state_dict.pop(state_dict_key))
            elif "self_attn.qkv_proj.weight" in state_dict_key or "self_attn.qkv_proj.bias" in state_dict_key:
                tensor = self._fuse_qkv(state_dict, state_dict_key)
            else:
                raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
            if param.shape != tensor.shape:
                raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
            # Match the parameter dtype explicitly, as the module's weight loaders do.
            if tensor.dtype != param.dtype:
                tensor = tensor.cast(param.dtype)
            param.copy_(tensor, False)