mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Support Paddle-OCR (#4396)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* init * update code * fix code style & disable thinking * adapt for common_engine.update_mm_requests_chunk_size * use 3d rope * use flash_attn_unpadded * opt siglip * update to be compatible with the latest codebase * fix typo * optim OCR performance * fix bug * fix bug * fix bug * fix bug * normlize name * modify xpu rope * revert logger * fix bug * fix bug * fix bug * support default_v1 * optim performance * fix bug --------- Co-authored-by: root <root@szzj-acg-tge1-fdda9.szzj.baidu.com> Co-authored-by: zhangyue66 <zhangyue66@baidu.com>
This commit is contained in:
15
fastdeploy/model_executor/models/paddleocr_vl/__init__.py
Normal file
15
fastdeploy/model_executor/models/paddleocr_vl/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
167
fastdeploy/model_executor/models/paddleocr_vl/config.py
Normal file
167
fastdeploy/model_executor/models/paddleocr_vl/config.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class PPOCRVisionConfig(PretrainedConfig):
|
||||
model_type = "paddleocr_vl"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
num_channels=3,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
tokens_per_second=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.tokens_per_second = tokens_per_second
|
||||
|
||||
|
||||
class PaddleOCRConfig(PretrainedConfig):
|
||||
model_type = "paddleocr_vl"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
sub_configs = {"vision_config": PPOCRVisionConfig}
|
||||
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=768,
|
||||
intermediate_size=11008,
|
||||
max_position_embeddings=32768,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
image_token_id=101304,
|
||||
video_token_id=101305,
|
||||
vision_start_token_id=101306,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=False,
|
||||
use_flash_attention=False,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
use_bias=False,
|
||||
rope_theta=10000,
|
||||
weight_share_add_bias=True,
|
||||
ignored_index=-100,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
hidden_dropout_prob=0.0,
|
||||
compression_ratio: float = 1.0,
|
||||
num_key_value_heads=None,
|
||||
max_sequence_length=None,
|
||||
tie_word_embeddings=False,
|
||||
vision_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Set default for tied embeddings if not specified.
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.use_flash_attention = use_flash_attention
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.head_dim = head_dim
|
||||
if hidden_act != "silu":
|
||||
raise NotImplementedError
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_size = hidden_size
|
||||
self.use_bias = use_bias
|
||||
self.weight_share_add_bias = weight_share_add_bias
|
||||
self.rope_theta = rope_theta
|
||||
self.ignored_index = ignored_index
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.compression_ratio = compression_ratio
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.max_sequence_length = max_sequence_length
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
# Currently, these configuration items are hard-coded
|
||||
self.fuse_rms_norm = True
|
||||
self.use_sparse_flash_attn = True
|
||||
self.use_var_len_flash_attn = False
|
||||
self.scale_qk_coeff = 1.0
|
||||
self.fuse_softmax_mask = False
|
||||
self.use_sparse_head_and_loss_fn = False
|
||||
self.use_recompute_loss_fn = False
|
||||
self.use_fused_head_and_loss_fn = False
|
||||
self.fuse_linear = False
|
||||
self.token_balance_seqlen = False
|
||||
self.use_rmsnorm = True
|
||||
self.fuse_ln = False
|
||||
self.cachekv_quant = False
|
||||
self.fuse_swiglu = False
|
||||
455
fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
Normal file
455
fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.graph_optimization.decorator import (
|
||||
support_graph_optimization,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
|
||||
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
|
||||
from fastdeploy.model_executor.models.model_base import (
|
||||
ModelCategory,
|
||||
ModelForCasualLM,
|
||||
ModelRegistry,
|
||||
)
|
||||
from fastdeploy.model_executor.utils import (
|
||||
default_weight_loader,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
|
||||
from .projector import Projector
|
||||
from .siglip import SiglipVisionModel
|
||||
|
||||
|
||||
@support_graph_optimization
|
||||
class PaddleOCRVLModel(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = fd_config.model_config
|
||||
self.num_layers = fd_config.model_config.num_hidden_layers
|
||||
fd_config.model_config.pretrained_config.prefix_name = "model"
|
||||
self._dtype = fd_config.model_config.torch_dtype
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
fd_config=fd_config,
|
||||
num_embeddings=fd_config.model_config.vocab_size,
|
||||
embedding_dim=fd_config.model_config.hidden_size,
|
||||
params_dtype=self._dtype,
|
||||
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
|
||||
)
|
||||
|
||||
self.layers = nn.LayerList(
|
||||
[
|
||||
Ernie4_5_DecoderLayer(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
for i, layer in enumerate(self.layers):
|
||||
layer.self_attn.attn = Attention(
|
||||
fd_config=fd_config,
|
||||
layer_id=i,
|
||||
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}.self_attn",
|
||||
use_neox_rotary_style=True,
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
|
||||
Args:
|
||||
state_dict (dict[str, np.ndarray | paddle.Tensor]):
|
||||
A dictionary containing model parameters, where keys are parameter names
|
||||
and values are NumPy arrays or PaddlePaddle tensors.
|
||||
"""
|
||||
self.embed_tokens.load_state_dict(state_dict)
|
||||
self.norm.load_state_dict(state_dict)
|
||||
for i in range(self.num_layers):
|
||||
logger.info(f"Start load layer {i}")
|
||||
self.layers[i].load_state_dict(state_dict)
|
||||
|
||||
def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor:
|
||||
return self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_embeddings: paddle.Tensor,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
hidden_states = input_embeddings
|
||||
residual = None
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ModelRegistry.register_model_class(
|
||||
architecture="PaddleOCRVLForConditionalGeneration",
|
||||
module_name="paddleocr_vl.paddleocr_vl",
|
||||
category=ModelCategory.MULTIMODAL,
|
||||
primary_use=ModelCategory.MULTIMODAL,
|
||||
)
|
||||
class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
|
||||
def __init__(self, fd_config):
|
||||
super().__init__(fd_config)
|
||||
|
||||
config = fd_config.model_config
|
||||
self.config = config
|
||||
self.mlp_AR = Projector(config, config.vision_config, prefix="mlp_AR")
|
||||
self.visual = SiglipVisionModel(config.vision_config, prefix="visual")
|
||||
self.model = PaddleOCRVLModel(fd_config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
fd_config=fd_config,
|
||||
embedding_dim=fd_config.model_config.hidden_size,
|
||||
num_embeddings=fd_config.model_config.vocab_size,
|
||||
prefix="lm_head",
|
||||
)
|
||||
|
||||
# Persistent buffers for CUDA graphs.
|
||||
if envs.FD_ENABLE_MAX_PREFILL:
|
||||
max_length = fd_config.scheduler_config.max_num_seqs * fd_config.model_config.max_model_len
|
||||
else:
|
||||
max_length = fd_config.model_config.max_model_len
|
||||
self._input_embeddings = paddle.zeros(
|
||||
[max_length, fd_config.model_config.hidden_size],
|
||||
dtype=fd_config.model_config.dtype,
|
||||
)
|
||||
|
||||
@paddle.no_grad()
|
||||
def load_weights(self, weights_iterator) -> None:
|
||||
"""
|
||||
Load model parameters from a given weights_iterator object.
|
||||
|
||||
Args:
|
||||
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
|
||||
"""
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("up_gate_proj", "gate_proj", "gate"),
|
||||
("up_gate_proj", "up_proj", "up"),
|
||||
("embed_tokens.embeddings", "embed_tokens", None),
|
||||
("lm_head.linear", "lm_head", None),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
|
||||
for loaded_weight_name, loaded_weight in weights_iterator:
|
||||
loaded_weight_name = (
|
||||
self.process_weights_before_loading_fn(loaded_weight_name)
|
||||
if getattr(self, "process_weights_before_loading_fn", None)
|
||||
else loaded_weight_name
|
||||
)
|
||||
if loaded_weight_name is None:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in loaded_weight_name:
|
||||
continue
|
||||
model_param_name = loaded_weight_name.replace(weight_name, param_name)
|
||||
if model_param_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[model_param_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
model_param_name = loaded_weight_name
|
||||
if model_param_name not in params_dict:
|
||||
continue
|
||||
param = params_dict[model_param_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
|
||||
weight_loader(param, loaded_weight)
|
||||
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
|
||||
process_weights_after_loading_fn(model_sublayer_name, param)
|
||||
|
||||
@paddle.no_grad()
|
||||
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
|
||||
Args:
|
||||
state_dict (dict[str, np.ndarray | paddle.Tensor]):
|
||||
A dictionary containing model parameters, where keys are parameter names
|
||||
and values are NumPy arrays or PaddlePaddle tensors.
|
||||
"""
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.visual.load_state_dict(state_dict)
|
||||
self.projector.load_state_dict(state_dict)
|
||||
self.lm_head.load_state_dict(state_dict)
|
||||
|
||||
@property
|
||||
def projector(self):
|
||||
return self.mlp_AR
|
||||
|
||||
@classmethod
|
||||
def name(self):
|
||||
return "PaddleOCRVLForConditionalGeneration"
|
||||
|
||||
def compute_logits(self, hidden_states: paddle.Tensor):
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = paddle.cast(logits, paddle.float32)
|
||||
logits[:, self.vocab_size :] = -float("inf")
|
||||
|
||||
return logits
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor] = None,
|
||||
) -> paddle.Tensor:
|
||||
input_embeddings = self.model.get_input_embeddings(ids_remove_padding=ids_remove_padding)
|
||||
image_mask = ids_remove_padding == self.model.config.image_token_id
|
||||
image_token_num = image_mask.sum()
|
||||
|
||||
if image_token_num > 0:
|
||||
input_embeddings[image_mask] = image_features.cast(self._dtype)
|
||||
return input_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
image_features: Optional[paddle.Tensor],
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
input_embeddings = self.get_input_embeddings(
|
||||
ids_remove_padding=ids_remove_padding, image_features=image_features
|
||||
)
|
||||
self._input_embeddings.copy_(input_embeddings, False)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_embeddings=self._input_embeddings,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PaddleOCRVLPretrainedModel(PretrainedModel):
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "PaddleOCRVLForConditionalGeneration"
|
||||
|
||||
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
|
||||
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||
|
||||
weight_infos = [
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
|
||||
True,
|
||||
tsm.GQA,
|
||||
),
|
||||
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight", False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.TEXT_EXPERT_ID}}}.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.TEXT_EXPERT_ID}}}.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.IMG_EXPERT_ID}}}.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.IMG_EXPERT_ID}}}.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(".embed_tokens.weight", False),
|
||||
WeightMeta("lm_head.weight", True),
|
||||
]
|
||||
|
||||
weight_vison = [
|
||||
# resampler_model
|
||||
WeightMeta("ernie.resampler_model.spatial_linear.0.weight", False),
|
||||
WeightMeta("resampler_model.spatial_linear.0.weight", False),
|
||||
# vision
|
||||
WeightMeta(
|
||||
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc2.weight", False),
|
||||
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc1.weight", True),
|
||||
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc1.bias", True),
|
||||
WeightMeta(
|
||||
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight",
|
||||
True,
|
||||
tsm.GQA,
|
||||
),
|
||||
WeightMeta(
|
||||
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias",
|
||||
True,
|
||||
tsm.GQA,
|
||||
),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
|
||||
"""
|
||||
get_tensor_parallel_mappings
|
||||
"""
|
||||
from fastdeploy.model_executor.models.tp_utils import (
|
||||
build_expanded_keys,
|
||||
has_prefix,
|
||||
split_or_merge_func_v1,
|
||||
)
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
vision_fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.vision_config.get("num_heads"),
|
||||
num_key_value_heads=config.vision_config.get("num_heads"),
|
||||
head_dim=config.vision_config.get("hidden_size") // config.vision_config.get("num_heads"),
|
||||
)
|
||||
|
||||
def get_tensor_parallel_split_mappings(
|
||||
num_layers: int,
|
||||
moe_num_experts: list[int],
|
||||
moe_layer_start_index: int,
|
||||
prefix_name: str,
|
||||
):
|
||||
base_actions = {}
|
||||
for weight_name, is_column, extra in cls.weight_infos:
|
||||
params = {
|
||||
"is_column": is_column,
|
||||
**({extra.value: True} if extra else {}),
|
||||
}
|
||||
|
||||
if "lm_head.weight" in weight_name or weight_name == "":
|
||||
key = weight_name
|
||||
elif not has_prefix(prefix_name, weight_name):
|
||||
key = f"{prefix_name}{weight_name}"
|
||||
else:
|
||||
key = weight_name
|
||||
base_actions[key] = partial(fn, **params)
|
||||
final_actions = {}
|
||||
final_actions = build_expanded_keys(
|
||||
base_actions,
|
||||
num_layers,
|
||||
(moe_layer_start_index if moe_layer_start_index > 0 else num_layers),
|
||||
text_num_experts=moe_num_experts[0],
|
||||
img_num_experts=moe_num_experts[1],
|
||||
)
|
||||
return final_actions
|
||||
|
||||
def get_vison_parallel_split_mappings(num_layers: int):
|
||||
base_actions = {}
|
||||
for weight_name, is_column, extra in cls.weight_vison:
|
||||
params = {
|
||||
"is_column": is_column,
|
||||
**({extra.value: True} if extra else {}),
|
||||
}
|
||||
base_actions[weight_name] = partial(vision_fn, **params)
|
||||
final_actions = {}
|
||||
final_actions = build_expanded_keys(
|
||||
base_actions,
|
||||
num_layers,
|
||||
)
|
||||
return final_actions
|
||||
|
||||
moe_layer_start_index = -1
|
||||
if isinstance(config.moe_layer_start_index, list):
|
||||
moe_layer_start_index = min(config.moe_layer_start_index)
|
||||
elif isinstance(config.moe_layer_start_index, int):
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(
|
||||
config.num_hidden_layers,
|
||||
config.moe_num_experts,
|
||||
moe_layer_start_index,
|
||||
config.prefix_name,
|
||||
)
|
||||
vision_mappings = get_vison_parallel_split_mappings(config.vision_config.get("depth"))
|
||||
|
||||
return {**mappings, **vision_mappings}
|
||||
107
fastdeploy/model_executor/models/paddleocr_vl/projector.py
Normal file
107
fastdeploy/model_executor/models/paddleocr_vl/projector.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
|
||||
class GELUActivation(nn.Layer):
|
||||
"""
|
||||
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
||||
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
||||
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
||||
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def __init__(self, use_gelu_python: bool = False):
|
||||
super().__init__()
|
||||
if use_gelu_python:
|
||||
self.act = self._gelu_python
|
||||
else:
|
||||
self.act = nn.functional.gelu
|
||||
|
||||
def _gelu_python(self, input):
|
||||
return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))
|
||||
|
||||
def forward(self, input):
|
||||
return self.act(input)
|
||||
|
||||
|
||||
class Projector(nn.Layer):
|
||||
|
||||
def __init__(self, text_config, vision_config, prefix=""):
|
||||
super().__init__()
|
||||
self.prefix_name = prefix
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
self.merge_kernel_size = (2, 2)
|
||||
|
||||
self.hidden_size = self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1]
|
||||
|
||||
self.pre_norm = nn.LayerNorm(self.vision_config.hidden_size, epsilon=1e-05)
|
||||
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
self.act = GELUActivation()
|
||||
self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size)
|
||||
|
||||
def forward(self, image_features, image_grid_thw):
|
||||
m1, m2 = self.merge_kernel_size
|
||||
if isinstance(image_features, (list, tuple)):
|
||||
processed_features = list()
|
||||
for image_feature, image_grid in zip(image_features, image_grid_thw):
|
||||
image_feature = self.pre_norm(image_feature) # shape: (T*H*W, D)
|
||||
t, h, w = image_grid
|
||||
from einops import rearrange
|
||||
|
||||
image_feature = rearrange(
|
||||
image_feature,
|
||||
"(t h p1 w p2) d -> (t h w) (p1 p2 d)",
|
||||
t=int(t),
|
||||
h=int(h // m1),
|
||||
p1=int(m1),
|
||||
w=int(w // m2),
|
||||
p2=int(m2),
|
||||
)
|
||||
hidden_states = self.linear_1(image_feature)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
processed_features.append(hidden_states)
|
||||
|
||||
return processed_features
|
||||
|
||||
dim = image_features.shape[-1]
|
||||
image_features = paddle.reshape(image_features, [-1, dim])
|
||||
hidden_states = self.pre_norm(image_features)
|
||||
hidden_states = paddle.reshape(hidden_states, [-1, self.hidden_size])
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for param_name, param in params_dict.items():
|
||||
state_dict_key = f"{self.prefix_name}.{param_name}"
|
||||
if state_dict_key not in state_dict:
|
||||
raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
|
||||
tensor = get_tensor(state_dict.pop(state_dict_key))
|
||||
if param.shape != tensor.shape:
|
||||
raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
|
||||
else:
|
||||
param.copy_(tensor, False)
|
||||
740
fastdeploy/model_executor/models/paddleocr_vl/siglip.py
Normal file
740
fastdeploy/model_executor/models/paddleocr_vl/siglip.py
Normal file
@@ -0,0 +1,740 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||
from paddleformers.transformers.activations import ACT2FN
|
||||
from paddleformers.transformers.model_utils import PretrainedModel
|
||||
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import slice_fn
|
||||
|
||||
try:
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
except:
|
||||
flash_attention_v3_varlen = None
|
||||
|
||||
from .config import PPOCRVisionConfig
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
Dh = x.shape[-1]
|
||||
x1 = x[..., : Dh // 2]
|
||||
x2 = x[..., Dh // 2 :]
|
||||
return paddle.concat([-x2, x1], axis=-1)
|
||||
|
||||
|
||||
def _ensure_cos_sin_dim(cos, sin, dim_needed):
|
||||
last = cos.shape[-1]
|
||||
if last == dim_needed:
|
||||
return cos, sin
|
||||
elif last * 2 == dim_needed:
|
||||
cos = paddle.concat([cos, cos], axis=-1)
|
||||
sin = paddle.concat([sin, sin], axis=-1)
|
||||
return cos, sin
|
||||
else:
|
||||
raise ValueError(f"Unexpected cos/sin last-dim: {last}, expected {dim_needed} or {dim_needed//2}")
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(x, cos, sin):
|
||||
orig_dtype = x.dtype
|
||||
x = x.astype("float32")
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed.astype(orig_dtype)
|
||||
|
||||
|
||||
class QKVLinear(nn.Linear):
|
||||
def __init__(self, config, in_features, out_features, weight_attr=None, bias_attr=None):
|
||||
super().__init__(in_features, out_features, weight_attr, bias_attr)
|
||||
self.config = config
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert self.head_dim * self.num_heads == self.embed_dim
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
self.bias.weight_loader = self.weight_loader
|
||||
|
||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
if not param._is_initialized():
|
||||
param.initialize()
|
||||
if loaded_shard_id == "q":
|
||||
param_shard_offset = 0
|
||||
param_shard_size = self.num_heads * self.head_dim
|
||||
elif loaded_shard_id == "k":
|
||||
param_shard_offset = self.num_heads * self.head_dim
|
||||
param_shard_size = self.num_heads * self.head_dim
|
||||
else:
|
||||
# loaded_shard_id == "v"
|
||||
param_shard_offset = self.num_heads * self.head_dim * 2
|
||||
param_shard_size = self.num_heads * self.head_dim
|
||||
|
||||
param = slice_fn(param, self.out_features, start=param_shard_offset, end=param_shard_offset + param_shard_size)
|
||||
assert param.shape == loaded_weight.shape, (
|
||||
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
|
||||
)
|
||||
# Ensure loaded weight dtype matches model param dtype
|
||||
if loaded_weight.dtype != param.dtype:
|
||||
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
|
||||
loaded_weight = loaded_weight.view(param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
param.copy_(loaded_weight, False)
|
||||
|
||||
|
||||
class SiglipAttention(nn.Layer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
assert self.head_dim * self.num_heads == self.embed_dim
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVLinear(config, self.embed_dim, self.embed_dim * 3, bias_attr=True)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
prop = paddle.device.cuda.get_device_properties()
|
||||
cc = prop.major * 10 + prop.minor
|
||||
is_current_sm_supported = cc >= 90
|
||||
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
|
||||
if is_current_sm_supported and is_paddle_supported:
|
||||
self.flash_attn_func = flash_attention_v3_varlen
|
||||
self.flash_attn_kwargs = {}
|
||||
else:
|
||||
self.flash_attn_func = flash_attn_unpadded
|
||||
self.flash_attn_kwargs = {"scale": self.scale, "training": False}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: paddle.Tensor, # [B, L, D]
|
||||
attention_mask: Optional[paddle.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cu_seqlens: Optional[List[paddle.Tensor]] = None,
|
||||
max_seqlen: Optional[paddle.Tensor] = None,
|
||||
rope_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, # (cos, sin)
|
||||
):
|
||||
B, seq_length, D = hidden_states.shape
|
||||
|
||||
qkv = (
|
||||
self.qkv_proj(hidden_states)
|
||||
.reshape(
|
||||
[
|
||||
seq_length,
|
||||
3,
|
||||
self.num_heads,
|
||||
-1,
|
||||
]
|
||||
)
|
||||
.transpose(perm=[1, 0, 2, 3])
|
||||
)
|
||||
q, k, v = qkv.unbind(axis=0)
|
||||
cos, sin = rope_emb
|
||||
|
||||
# --------
|
||||
q = apply_rotary_pos_emb_vision(q, cos, sin)
|
||||
k = apply_rotary_pos_emb_vision(k, cos, sin)
|
||||
|
||||
attn_output = self.flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen,
|
||||
causal=False,
|
||||
**self.flash_attn_kwargs,
|
||||
)[0]
|
||||
# --------
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class SiglipVisionEmbeddings(nn.Layer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size # 1152
|
||||
self.image_size = config.image_size # 384
|
||||
self.patch_size = config.patch_size # 14
|
||||
|
||||
self.patch_embedding = nn.Conv2D(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="VALID",
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2 # 729
|
||||
self.num_positions = self.num_patches
|
||||
self.cache_position_embedding = dict()
|
||||
self.cache_position_count = dict()
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
|
||||
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
paddle.arange(self.num_positions).unsqueeze(0),
|
||||
persistable=False,
|
||||
)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings, height: int, width: int, is_after_patchify: bool = False):
|
||||
|
||||
num_positions = self.position_embedding.weight.shape[0]
|
||||
|
||||
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
if is_after_patchify:
|
||||
new_height = height
|
||||
new_width = width
|
||||
else:
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = paddle.to_tensor(num_positions**0.5, dtype=paddle.int64)
|
||||
patch_pos_embed = patch_pos_embed.reshape((1, sqrt_num_positions, sqrt_num_positions, dim))
|
||||
patch_pos_embed = patch_pos_embed.transpose((0, 3, 1, 2))
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.transpose((0, 2, 3, 1)).reshape((1, -1, dim))
|
||||
return patch_pos_embed
|
||||
|
||||
@staticmethod
|
||||
def flatten_list(image_grid_thw):
|
||||
tmp_image_grid_thw = list()
|
||||
for image_grid in image_grid_thw:
|
||||
if isinstance(image_grid, list):
|
||||
tmp_image_grid_thw.extend(image_grid)
|
||||
else:
|
||||
tmp_image_grid_thw.append(image_grid)
|
||||
return tmp_image_grid_thw
|
||||
|
||||
def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache=20):
|
||||
grid = (h, w)
|
||||
if grid in self.cache_position_embedding:
|
||||
self.cache_position_count[grid] += 1
|
||||
return self.cache_position_embedding[grid]
|
||||
|
||||
if len(self.cache_position_embedding) >= max_cache:
|
||||
min_hit_grid = min(self.cache_position_count, key=self.cache_position_count.get)
|
||||
self.cache_position_count.pop(min_hit_grid)
|
||||
self.cache_position_embedding.pop(min_hit_grid)
|
||||
|
||||
position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
|
||||
self.cache_position_count[grid] = 1
|
||||
self.cache_position_embedding[grid] = position_embedding
|
||||
return position_embedding
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: paddle.Tensor, # [B, L, C, H, W]
|
||||
position_ids: Optional[paddle.Tensor] = None, # [B or 1, S]
|
||||
image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> paddle.Tensor:
|
||||
if pixel_values.dim() == 4:
|
||||
pixel_values = pixel_values.unsqueeze(0)
|
||||
if pixel_values.dim() == 5:
|
||||
assert position_ids is not None
|
||||
from einops import rearrange
|
||||
|
||||
batch_size, squence_len, channel, height, width = pixel_values.shape
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
embeddings = patch_embeds.flatten(-2).squeeze(-1)
|
||||
embeddings = rearrange(embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len)
|
||||
# todo: not debug
|
||||
if interpolate_pos_encoding and image_grid_thw is not None:
|
||||
flatten_image_grid_thw = self.flatten_list(image_grid_thw)
|
||||
flatten_image_grid_thw = np.array(flatten_image_grid_thw)
|
||||
assert batch_size == 1
|
||||
start = 0
|
||||
|
||||
assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], (
|
||||
flatten_image_grid_thw,
|
||||
embeddings.shape,
|
||||
)
|
||||
embeddings = embeddings.squeeze(0)
|
||||
tmp_embeddings = list()
|
||||
for image_grid in image_grid_thw:
|
||||
t, h, w = image_grid
|
||||
end = start + t * h * w
|
||||
image_embeddings = embeddings[int(start) : int(end), :]
|
||||
position_embedding = (
|
||||
self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).tile((t, 1))
|
||||
).astype(image_embeddings.dtype)
|
||||
image_embeddings = image_embeddings + position_embedding
|
||||
tmp_embeddings.append(image_embeddings)
|
||||
start = end
|
||||
embeddings = paddle.concat(tmp_embeddings, axis=0).unsqueeze(0)
|
||||
else:
|
||||
embeddings = embeddings + self.packing_position_embedding(position_ids)
|
||||
return embeddings
|
||||
else:
|
||||
raise NotImplementedError(str(pixel_values.shape))
|
||||
|
||||
|
||||
class SiglipMLP(nn.Layer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.hidden_act == "gelu_pytorch_tanh":
|
||||
config.hidden_act = "silu"
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SiglipEncoderLayer(paddle.nn.Layer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.layer_norm1 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
|
||||
self.self_attn = SiglipAttention(config)
|
||||
self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(config)
|
||||
|
||||
# @paddle.jit.to_static
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
output_attentions=False,
|
||||
cu_seqlens=None,
|
||||
max_seqlen=None,
|
||||
rope_emb=None,
|
||||
):
|
||||
|
||||
residual = hidden_states
|
||||
############################
|
||||
ln1_out = self.layer_norm1(hidden_states)
|
||||
|
||||
x = self.self_attn(
|
||||
hidden_states=ln1_out,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
rope_emb=rope_emb,
|
||||
)
|
||||
|
||||
hs_post_attn = residual + x
|
||||
|
||||
residual = hs_post_attn
|
||||
ln2_out = self.layer_norm2(residual)
|
||||
|
||||
mlp_out = self.mlp(ln2_out)
|
||||
|
||||
hidden_states_out = residual + mlp_out
|
||||
|
||||
outputs = (hidden_states_out,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SigLIPRotaryEmbedding(nn.Layer):
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.rope_init()
|
||||
|
||||
def rope_init(self):
|
||||
arange = paddle.arange(0, self.dim, 2, dtype="float32")
|
||||
inv_freq = 1.0 / (self.theta ** (arange / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq.astype(paddle.get_default_dtype()), persistable=False)
|
||||
|
||||
def forward(self, seqlen: int) -> paddle.Tensor:
|
||||
seq = paddle.arange(seqlen, dtype=self.inv_freq.dtype)
|
||||
freqs = paddle.outer(seq, self.inv_freq)
|
||||
return freqs
|
||||
|
||||
|
||||
class SiglipEncoder(nn.Layer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
num_heads = config.num_attention_heads
|
||||
head_dim = embed_dim // num_heads
|
||||
self.layers = nn.LayerList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@staticmethod
|
||||
def flatten_list(image_grid_thw):
|
||||
tmp_image_grid_thw = list()
|
||||
for image_grid in image_grid_thw:
|
||||
if isinstance(image_grid, list):
|
||||
tmp_image_grid_thw.extend(image_grid)
|
||||
else:
|
||||
tmp_image_grid_thw.append(image_grid)
|
||||
return tmp_image_grid_thw
|
||||
|
||||
def build_window_index(self, image_grid, window_size):
|
||||
"""
|
||||
返回:
|
||||
window_indices: int64 [sum(t*h*w_valid)]
|
||||
cu_seqlens_within_windows: int32 [num_windows_total*t],首位补 0 的前缀和
|
||||
"""
|
||||
from einops import rearrange
|
||||
|
||||
window_indices = list()
|
||||
pad_values = -100
|
||||
start_window_index = 0
|
||||
cu_seqlens_within_windows = list()
|
||||
|
||||
for t, h, w in map(int, image_grid):
|
||||
window_index = paddle.arange(t * h * w).reshape((t, h, w))
|
||||
pad_h = (-h) % window_size
|
||||
pad_w = (-w) % window_size
|
||||
assert pad_h >= 0 and pad_w >= 0, (pad_h, pad_w)
|
||||
window_index = F.pad(window_index, (0, pad_w, 0, pad_h), value=pad_values)
|
||||
window_index = rearrange(
|
||||
window_index,
|
||||
"t (h p1) (w p2) -> t (h w) (p1 p2)",
|
||||
p1=window_size,
|
||||
p2=window_size,
|
||||
)
|
||||
window_seqlens = (window_index != pad_values).long().sum(-1).reshape(-1)
|
||||
window_index = window_index.reshape(-1)
|
||||
window_index = window_index[window_index != pad_values]
|
||||
window_indices.append(window_index + start_window_index)
|
||||
cu_seqlens_within_windows.append(window_seqlens.cumsum(0) + start_window_index)
|
||||
start_window_index += t * h * w
|
||||
window_indices = paddle.concat(window_indices, axis=0)
|
||||
cu_seqlens_within_windows = paddle.concat(cu_seqlens_within_windows, axis=0)
|
||||
cu_seqlens_within_windows = F.pad(cu_seqlens_within_windows, (1, 0), value=0).astype("int32")
|
||||
return window_indices, cu_seqlens_within_windows
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: paddle.Tensor,
|
||||
attention_mask: Optional[paddle.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cu_seqlens: Optional[paddle.Tensor] = None,
|
||||
image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
|
||||
height_position_ids: Optional[paddle.Tensor] = None,
|
||||
width_position_ids: Optional[paddle.Tensor] = None,
|
||||
use_rope: Optional[bool] = False,
|
||||
window_size: Optional[int] = -1,
|
||||
vision_or_text: str = "vision",
|
||||
):
|
||||
assert vision_or_text in ["vision", "text"]
|
||||
use_window_attn = window_size > 0 and vision_or_text == "vision"
|
||||
use_rope = (use_rope is True) and (vision_or_text == "vision")
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
hidden_states = inputs_embeds
|
||||
attention_mask = attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None
|
||||
|
||||
if use_rope is True:
|
||||
flatten_image_grid_thw = self.flatten_list(image_grid_thw)
|
||||
flatten_image_grid_thw = np.array(flatten_image_grid_thw)
|
||||
assert sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1], (
|
||||
flatten_image_grid_thw,
|
||||
hidden_states.shape,
|
||||
)
|
||||
|
||||
if width_position_ids is None or height_position_ids is None:
|
||||
split_hids = list()
|
||||
split_wids = list()
|
||||
for t, h, w in flatten_image_grid_thw:
|
||||
t, h, w = map(int, (t, h, w))
|
||||
image_pids = paddle.arange(t * h * w) % (h * w)
|
||||
sample_hids = image_pids // w
|
||||
sample_wids = image_pids % w
|
||||
split_hids.append(sample_hids)
|
||||
split_wids.append(sample_wids)
|
||||
width_position_ids = paddle.concat(split_wids, axis=0)
|
||||
height_position_ids = paddle.concat(split_hids, axis=0)
|
||||
|
||||
window_indices, cu_seqlens_within_windows = None, None
|
||||
|
||||
if use_window_attn:
|
||||
window_indices, cu_seqlens_within_windows = self.build_window_index(
|
||||
flatten_image_grid_thw, window_size
|
||||
)
|
||||
reversed_window_indices = window_indices.argsort()
|
||||
height_position_ids = height_position_ids[window_indices]
|
||||
width_position_ids = width_position_ids[window_indices]
|
||||
|
||||
pids = paddle.stack([height_position_ids, width_position_ids], axis=-1).astype(paddle.int64)
|
||||
max_grid_size = pids.max() + 1
|
||||
rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
|
||||
|
||||
rope_emb = rope_emb_max_grid[pids].flatten(1)
|
||||
rope_emb = rope_emb.tile((1, 2))
|
||||
cos = rope_emb.cos().astype("float32")
|
||||
sin = rope_emb.sin().astype("float32")
|
||||
cos = cos.unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2)
|
||||
rope_emb = (cos, sin)
|
||||
else:
|
||||
rope_emb = None
|
||||
|
||||
window_indices, cu_seqlens_within_windows = None, None
|
||||
|
||||
if use_window_attn:
|
||||
flatten_image_grid_thw = self.flatten_list(image_grid_thw)
|
||||
assert (
|
||||
sum([np.prod(x.astype("float32").cpu().numpy()) for x in flatten_image_grid_thw])
|
||||
== hidden_states.shape[1]
|
||||
), (flatten_image_grid_thw, hidden_states.shape)
|
||||
|
||||
window_indices, cu_seqlens_within_windows = self.build_window_index(
|
||||
flatten_image_grid_thw, window_size
|
||||
)
|
||||
reversed_window_indices = window_indices.argsort()
|
||||
|
||||
if use_window_attn:
|
||||
assert cu_seqlens_within_windows is not None
|
||||
attn_cu_seqlens = cu_seqlens_within_windows
|
||||
hidden_states = hidden_states[:, window_indices, :]
|
||||
else:
|
||||
attn_cu_seqlens = cu_seqlens
|
||||
|
||||
max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (
|
||||
(hidden_states[:, reversed_window_indices, :],) if use_window_attn else (hidden_states,)
|
||||
)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
cu_seqlens=attn_cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
rope_emb=rope_emb,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if use_window_attn:
|
||||
hidden_states = hidden_states[:, reversed_window_indices, :]
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SiglipMultiheadAttentionPoolingHead(nn.Layer):
|
||||
"""Multihead Attention Pooling."""
|
||||
|
||||
def __init__(self, config: PPOCRVisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.probe = self.create_parameter(
|
||||
shape=(1, 1, config.hidden_size),
|
||||
default_initializer=paddle.nn.initializer.Normal(),
|
||||
)
|
||||
self.attention = nn.MultiHeadAttention(config.hidden_size, config.num_attention_heads)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(config)
|
||||
|
||||
def forward(self, hidden_state, key_padding_mask=None):
|
||||
batch_size = hidden_state.shape[0]
|
||||
probe = self.probe.tile((batch_size, 1, 1))
|
||||
|
||||
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
||||
|
||||
residual = hidden_state
|
||||
hidden_state = self.layernorm(hidden_state)
|
||||
hidden_state = residual + self.mlp(hidden_state)
|
||||
|
||||
return hidden_state[:, 0]
|
||||
|
||||
|
||||
class SiglipVisionTransformer(nn.Layer):
|
||||
def __init__(self, config: PPOCRVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = SiglipVisionEmbeddings(config)
|
||||
self.encoder = SiglipEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, epsilon=config.layer_norm_eps)
|
||||
self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
|
||||
if self.use_head:
|
||||
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
attention_mask=None,
|
||||
sample_indices=None,
|
||||
image_indices=None,
|
||||
position_ids=None,
|
||||
height_position_ids=None,
|
||||
width_position_ids=None,
|
||||
cu_seqlens=None,
|
||||
padding_mask=None,
|
||||
vision_return_embed_list: Optional[bool] = False,
|
||||
image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
|
||||
return_pooler_output: Optional[bool] = True,
|
||||
use_rope: Optional[bool] = False,
|
||||
window_size: Optional[bool] = -1,
|
||||
):
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
position_ids=position_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
)
|
||||
last_hidden_state = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
image_grid_thw=image_grid_thw,
|
||||
use_rope=use_rope,
|
||||
height_position_ids=height_position_ids,
|
||||
width_position_ids=width_position_ids,
|
||||
window_size=window_size,
|
||||
vision_or_text="vision",
|
||||
)
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
sample_hidden_state = list()
|
||||
assert cu_seqlens is not None
|
||||
for i in range(cu_seqlens.shape[0] - 1):
|
||||
start = cu_seqlens[i]
|
||||
end = cu_seqlens[i + 1]
|
||||
tensor = last_hidden_state[:, start:end, :].squeeze(0)
|
||||
sample_hidden_state.append(tensor)
|
||||
|
||||
return sample_hidden_state
|
||||
|
||||
|
||||
class SiglipVisionModel(PretrainedModel):
|
||||
config_class = PPOCRVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(self, config: PPOCRVisionConfig, prefix=""):
|
||||
super().__init__(config)
|
||||
self.prefix_name = prefix
|
||||
self.vision_model = SiglipVisionTransformer(config)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Layer:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
sample_indices=None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
position_ids=None,
|
||||
vision_return_embed_list: Optional[bool] = False,
|
||||
image_grid_thw: Optional[List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]] = None,
|
||||
cu_seqlens=None,
|
||||
return_pooler_output: Optional[bool] = True,
|
||||
use_rope: Optional[bool] = False,
|
||||
window_size: Optional[bool] = -1,
|
||||
):
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
position_ids=position_ids,
|
||||
vision_return_embed_list=vision_return_embed_list,
|
||||
image_grid_thw=image_grid_thw,
|
||||
sample_indices=sample_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
return_pooler_output=return_pooler_output,
|
||||
use_rope=use_rope,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for param_name, param in params_dict.items():
|
||||
state_dict_key = f"{self.prefix_name}.{param_name}"
|
||||
if state_dict_key not in state_dict:
|
||||
if "self_attn.qkv_proj.weight" in state_dict_key:
|
||||
q_weight_key = state_dict_key.replace("qkv_proj", "q_proj")
|
||||
k_weight_key = state_dict_key.replace("qkv_proj", "k_proj")
|
||||
v_weight_key = state_dict_key.replace("qkv_proj", "v_proj")
|
||||
q_tensor = get_tensor(state_dict.pop(q_weight_key))
|
||||
k_tensor = get_tensor(state_dict.pop(k_weight_key))
|
||||
v_tensor = get_tensor(state_dict.pop(v_weight_key))
|
||||
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor], axis=-1).transpose([1, 0])
|
||||
tensor = paddle.transpose(weight_tensor, perm=[1, 0])
|
||||
elif "self_attn.qkv_proj.bias" in state_dict_key:
|
||||
q_bias_key = state_dict_key.replace("qkv_proj", "q_proj")
|
||||
k_bias_key = state_dict_key.replace("qkv_proj", "k_proj")
|
||||
v_bias_key = state_dict_key.replace("qkv_proj", "v_proj")
|
||||
q_bias = get_tensor(state_dict.pop(q_bias_key))
|
||||
k_bias = get_tensor(state_dict.pop(k_bias_key))
|
||||
v_bias = get_tensor(state_dict.pop(v_bias_key))
|
||||
qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
|
||||
tensor = qkv_bias
|
||||
else:
|
||||
raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
|
||||
else:
|
||||
tensor = get_tensor(state_dict.pop(state_dict_key))
|
||||
if param.shape != tensor.shape:
|
||||
raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
|
||||
else:
|
||||
param.copy_(tensor, False)
|
||||
Reference in New Issue
Block a user