mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Others] Delete PaddleOCR Useless Function (#4815)
* fix paddleocr prefix cache bug
* add test for paddleocr_vl
* disable prefix-caching in ocr
* add test for paddleocr_vl
* Fix top_p for rejection sampling
* delete useless func for paddleocr

---------

Co-authored-by: ming1753 <ideaminghp@163.com>
Co-authored-by: ming1753 <61511741+ming1753@users.noreply.github.com>
@@ -15,15 +15,12 @@
"""

import re
from functools import partial
from typing import Dict, Optional, Union

import numpy as np
import paddle
import paddle.nn as nn
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.config import FDConfig

@@ -95,21 +92,6 @@ class PaddleOCRVLModel(nn.Layer):
            prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
        )

    def load_state_dict(self, state_dict):
        """
        Load model parameters from a given state dictionary.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                A dictionary containing model parameters, where keys are parameter names
                and values are NumPy arrays or PaddlePaddle tensors.
        """
        self.embed_tokens.load_state_dict(state_dict)
        self.norm.load_state_dict(state_dict)
        for i in range(self.num_layers):
            logger.info(f"Start load layer {i}")
            self.layers[i].load_state_dict(state_dict)

    def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor:
        return self.embed_tokens(ids_remove_padding=ids_remove_padding)

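As the removed docstring notes, this loader consumed a plain name-to-tensor dictionary. A minimal usage sketch, assuming an already constructed model instance and purely illustrative parameter names; neither the names nor the entry point below reflect FastDeploy's actual checkpoint layout:

import numpy as np
import paddle

# Hypothetical state dict: keys are parameter names, values are NumPy
# arrays or paddle.Tensors, as the removed docstring describes.
state_dict = {
    "model.embed_tokens.weight": np.zeros((1000, 64), dtype="float32"),  # illustrative name and shape
    "model.norm.weight": paddle.ones([64], dtype="float32"),             # illustrative name and shape
    # ... one entry per decoder-layer parameter ...
}

# The removed method popped entries from this dict submodule by submodule:
# model.load_state_dict(state_dict)
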
@@ -364,92 +346,3 @@ class PaddleOCRVLPretrainedModel(PretrainedModel):
            tsm.GQA,
        ),
    ]

    @classmethod
    def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
        """
        get_tensor_parallel_mappings
        """
        from fastdeploy.model_executor.models.tp_utils import (
            build_expanded_keys,
            has_prefix,
            split_or_merge_func_v1,
        )

        fn = split_or_merge_func_v1(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )
        vision_fn = split_or_merge_func_v1(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.vision_config.get("num_heads"),
            num_key_value_heads=config.vision_config.get("num_heads"),
            head_dim=config.vision_config.get("hidden_size") // config.vision_config.get("num_heads"),
        )

        def get_tensor_parallel_split_mappings(
            num_layers: int,
            moe_num_experts: list[int],
            moe_layer_start_index: int,
            prefix_name: str,
        ):
            base_actions = {}
            for weight_name, is_column, extra in cls.weight_infos:
                params = {
                    "is_column": is_column,
                    **({extra.value: True} if extra else {}),
                }

                if "lm_head.weight" in weight_name or weight_name == "":
                    key = weight_name
                elif not has_prefix(prefix_name, weight_name):
                    key = f"{prefix_name}{weight_name}"
                else:
                    key = weight_name
                base_actions[key] = partial(fn, **params)
            final_actions = {}
            final_actions = build_expanded_keys(
                base_actions,
                num_layers,
                (moe_layer_start_index if moe_layer_start_index > 0 else num_layers),
                text_num_experts=moe_num_experts[0],
                img_num_experts=moe_num_experts[1],
            )
            return final_actions

        def get_vison_parallel_split_mappings(num_layers: int):
            base_actions = {}
            for weight_name, is_column, extra in cls.weight_vison:
                params = {
                    "is_column": is_column,
                    **({extra.value: True} if extra else {}),
                }
                base_actions[weight_name] = partial(vision_fn, **params)
            final_actions = {}
            final_actions = build_expanded_keys(
                base_actions,
                num_layers,
            )
            return final_actions

        moe_layer_start_index = -1
        if isinstance(config.moe_layer_start_index, list):
            moe_layer_start_index = min(config.moe_layer_start_index)
        elif isinstance(config.moe_layer_start_index, int):
            moe_layer_start_index = config.moe_layer_start_index

        mappings = get_tensor_parallel_split_mappings(
            config.num_hidden_layers,
            config.moe_num_experts,
            moe_layer_start_index,
            config.prefix_name,
        )
        vision_mappings = get_vison_parallel_split_mappings(config.vision_config.get("depth"))

        return {**mappings, **vision_mappings}

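The removed _get_tensor_parallel_mappings builds a dictionary from checkpoint weight names to partial-bound split/merge functions. A minimal, self-contained sketch of how such a mapping can be applied to shard a checkpoint across tensor-parallel ranks; the toy split_or_merge function and the weight names are assumptions for illustration, not FastDeploy's tp_utils API:

from functools import partial

import numpy as np


def split_or_merge(weight, *, is_column, tp_degree, tp_rank):
    # Toy split function: column-parallel weights are split on the last
    # axis, row-parallel weights on the first axis.
    axis = -1 if is_column else 0
    return np.split(weight, tp_degree, axis=axis)[tp_rank]


# Name -> action mapping, mirroring base_actions[key] = partial(fn, **params).
actions = {
    "layers.0.self_attn.qkv_proj.weight": partial(split_or_merge, is_column=True),
    "layers.0.mlp.down_proj.weight": partial(split_or_merge, is_column=False),
}

state_dict = {
    "layers.0.self_attn.qkv_proj.weight": np.arange(24, dtype="float32").reshape(4, 6),
    "layers.0.mlp.down_proj.weight": np.arange(24, dtype="float32").reshape(6, 4),
}

# Each rank keeps only its shard of every mapped weight.
tp_degree, tp_rank = 2, 0
sharded = {
    name: actions[name](weight, tp_degree=tp_degree, tp_rank=tp_rank) if name in actions else weight
    for name, weight in state_dict.items()
}
print({name: weight.shape for name, weight in sharded.items()})
# qkv_proj.weight -> (4, 3), down_proj.weight -> (3, 4) on rank 0 of 2
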
@@ -110,15 +110,3 @@ class Projector(nn.Layer):
            else:
                loaded_weight = loaded_weight.cast(param.dtype)
            param.copy_(loaded_weight, False)

    def load_state_dict(self, state_dict):
        params_dict = dict(self.named_parameters())
        for param_name, param in params_dict.items():
            state_dict_key = f"{self.prefix_name}.{param_name}"
            if state_dict_key not in state_dict:
                raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
            tensor = get_tensor(state_dict.pop(state_dict_key))
            if param.shape != tensor.shape:
                raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
            else:
                param.copy_(tensor, False)

@@ -736,35 +736,3 @@ class SiglipVisionModel(PretrainedModel):
            use_rope=use_rope,
            window_size=window_size,
        )

    def load_state_dict(self, state_dict):
        params_dict = dict(self.named_parameters())
        for param_name, param in params_dict.items():
            state_dict_key = f"{self.prefix_name}.{param_name}"
            if state_dict_key not in state_dict:
                if "self_attn.qkv_proj.weight" in state_dict_key:
                    q_weight_key = state_dict_key.replace("qkv_proj", "q_proj")
                    k_weight_key = state_dict_key.replace("qkv_proj", "k_proj")
                    v_weight_key = state_dict_key.replace("qkv_proj", "v_proj")
                    q_tensor = get_tensor(state_dict.pop(q_weight_key))
                    k_tensor = get_tensor(state_dict.pop(k_weight_key))
                    v_tensor = get_tensor(state_dict.pop(v_weight_key))
                    weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor], axis=-1).transpose([1, 0])
                    tensor = paddle.transpose(weight_tensor, perm=[1, 0])
                elif "self_attn.qkv_proj.bias" in state_dict_key:
                    q_bias_key = state_dict_key.replace("qkv_proj", "q_proj")
                    k_bias_key = state_dict_key.replace("qkv_proj", "k_proj")
                    v_bias_key = state_dict_key.replace("qkv_proj", "v_proj")
                    q_bias = get_tensor(state_dict.pop(q_bias_key))
                    k_bias = get_tensor(state_dict.pop(k_bias_key))
                    v_bias = get_tensor(state_dict.pop(v_bias_key))
                    qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1)
                    tensor = qkv_bias
                else:
                    raise ValueError(f"The key {state_dict_key} does not exist in state_dict. ")
            else:
                tensor = get_tensor(state_dict.pop(state_dict_key))
            if param.shape != tensor.shape:
                raise ValueError(f"{state_dict_key} param.shape={param.shape} tensor.shape={tensor.shape}")
            else:
                param.copy_(tensor, False)

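The removed SiglipVisionModel.load_state_dict fused the separate q/k/v checkpoint entries into a single qkv_proj tensor by concatenation. A minimal sketch of that fusion step in isolation, with illustrative shapes and key names that are assumptions rather than the real checkpoint layout:

import paddle

hidden = 8  # illustrative hidden size

# Separate projection weights as they might appear in a checkpoint.
checkpoint = {
    "blocks.0.self_attn.q_proj.weight": paddle.rand([hidden, hidden]),
    "blocks.0.self_attn.k_proj.weight": paddle.rand([hidden, hidden]),
    "blocks.0.self_attn.v_proj.weight": paddle.rand([hidden, hidden]),
}

# Concatenate along the last axis to build the fused qkv weight,
# mirroring paddle.concat([q, k, v], axis=-1) in the removed loader.
qkv_weight = paddle.concat(
    [
        checkpoint["blocks.0.self_attn.q_proj.weight"],
        checkpoint["blocks.0.self_attn.k_proj.weight"],
        checkpoint["blocks.0.self_attn.v_proj.weight"],
    ],
    axis=-1,
)
print(qkv_weight.shape)  # [8, 24]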