mirror of https://github.com/PaddlePaddle/FastDeploy.git
rl update (#2861)
@@ -519,43 +519,6 @@ class DFNRopeVisionTransformerPretrainedModel(PretrainedModel):
         """
         return self.blocks[0].mlp.fc2.weight.dtype
 
-    def get_name_mappings_to_training(self, ):
-        """ get_name_mappings_to_training """
-        infer_to_train = {}
-
-        # vit train names
-        vit_names = [
-            "vision_model.patch_embed.proj.weight", "vision_model.ln.weight",
-            "vision_model.ln.bias"
-        ]
-
-        vit_layer = 32
-        for layer_idx in range(vit_layer):
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm1.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm1.bias")
-
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm2.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.norm2.bias")
-
-            vit_names.append(
-                f"vision_model.blocks.{layer_idx}.attn.qkv.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.attn.qkv.bias")
-
-            vit_names.append(
-                f"vision_model.blocks.{layer_idx}.attn.proj.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.attn.proj.bias")
-
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc1.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc1.bias")
-
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc2.weight")
-            vit_names.append(f"vision_model.blocks.{layer_idx}.mlp.fc2.bias")
-
-        for train_name in vit_names:
-            infer_to_train[train_name] = train_name
-
-        return infer_to_train
-
     def rot_pos_emb(self, grid_thw, num_pad=0):
         """_summary_
 
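Context for the deletion above: get_name_mappings_to_training produced an identity mapping over the ViT parameter names, so every inference-side name mapped to the identical training-side name. The sketch below shows, under assumed names (sync_weights, infer_model and train_state_dict are illustrative, not FastDeploy APIs), how such a mapping is typically consumed when copying trainer weights into the inference model:

    # Hypothetical helper: copy training weights into the inference model
    # using an infer_name -> train_name mapping like the one removed above.
    def sync_weights(infer_model, train_state_dict, infer_to_train):
        infer_state = {}
        for infer_name, train_name in infer_to_train.items():
            if train_name not in train_state_dict:
                raise KeyError(f"missing training weight: {train_name}")
            # For the ViT mapping both names are identical, so this is a
            # straight copy keyed by the shared name.
            infer_state[infer_name] = train_state_dict[train_name]
        infer_model.set_state_dict(infer_state)  # paddle.nn.Layer API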
@@ -513,7 +513,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             model_config.spatial_conv_size,
             model_config.temporal_conv_size,
             config=model_config,
-            prefix_name="ernie.resampler_model",
+            prefix_name="resampler_model",
         )
         resampler_model = paddle.amp.decorate(
             models=resampler_model, level="O2", dtype="bfloat16"
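The resampler is now constructed with prefix_name="resampler_model" instead of "ernie.resampler_model", so its weights are looked up without the ernie. prefix first; checkpoints that still carry the prefix are handled by the fallback added in the last hunk below. A minimal illustration of the two key forms that can resolve for one parameter (the parameter name is taken from the resampler list in the next hunk):

    # Illustration with assumed values: the two state_dict keys that can now
    # satisfy a lookup for a single resampler parameter.
    prefix_name = "resampler_model"
    param_name = "spatial_linear.0.weight"  # example parameter from the list below
    primary_key = f"{prefix_name}.{param_name}"         # "resampler_model.spatial_linear.0.weight"
    fallback_key = f"ernie.{prefix_name}.{param_name}"  # "ernie.resampler_model.spatial_linear.0.weight"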
@@ -210,31 +210,6 @@ class VariableResolutionResamplerModel(nn.Layer):
             mark_as_sequence_parallel_parameter(self.mlp.bias)
             mark_as_sequence_parallel_parameter(self.after_norm.weight)
 
-    def get_name_mappings_to_training(self, ):
-        """ get_name_mappings_to_training """
-        infer_to_train = {}
-        resampler_names = [
-            "ernie.resampler_model.spatial_linear.0.weight",
-            "ernie.resampler_model.spatial_linear.0.bias",
-            "ernie.resampler_model.spatial_linear.2.weight",
-            "ernie.resampler_model.spatial_linear.2.bias",
-            "ernie.resampler_model.spatial_linear.3.weight",
-            "ernie.resampler_model.spatial_linear.3.bias",
-            "ernie.resampler_model.temporal_linear.0.weight",
-            "ernie.resampler_model.temporal_linear.0.bias",
-            "ernie.resampler_model.temporal_linear.2.weight",
-            "ernie.resampler_model.temporal_linear.2.bias",
-            "ernie.resampler_model.temporal_linear.3.weight",
-            "ernie.resampler_model.temporal_linear.3.bias",
-            "ernie.resampler_model.mlp.weight",
-            "ernie.resampler_model.mlp.bias",
-            "ernie.resampler_model.after_norm.weight",
-        ]
-        for train_name in resampler_names:
-            infer_to_train[train_name[len("ernie."):]] = train_name
-
-        return infer_to_train
-
     def spatial_conv_reshape(self, x, spatial_conv_size):
         """
         Reshape before the Linear so that the Linear can mimic a conv's receptive field
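Unlike the ViT mapping, the removed resampler mapping stripped the leading "ernie." from each training name to form the inference-side key. A minimal sketch of that transformation, using two names from the list above:

    # The removed loop mapped, for example,
    #   "resampler_model.mlp.weight" -> "ernie.resampler_model.mlp.weight"
    resampler_names = [
        "ernie.resampler_model.mlp.weight",
        "ernie.resampler_model.after_norm.weight",
    ]
    infer_to_train = {name[len("ernie."):]: name for name in resampler_names}
    assert infer_to_train["resampler_model.mlp.weight"] == "ernie.resampler_model.mlp.weight"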
@@ -376,9 +351,11 @@ class VariableResolutionResamplerModel(nn.Layer):
         for param_name, param in params_dict.items():
             state_dict_key = f"{self.prefix_name}.{param_name}"
             if state_dict_key not in state_dict:
-                raise ValueError(
-                    f"The key {state_dict_key} does not exist in state_dict. "
-                )
+                state_dict_key = f"ernie.{self.prefix_name}.{param_name}"
+                if state_dict_key not in state_dict:
+                    raise ValueError(
+                        f"The key {state_dict_key} does not exist in state_dict. "
+                    )
             tensor = get_tensor(state_dict.pop(state_dict_key))
             if param.shape != tensor.shape:
                 raise ValueError(
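The reworked loading loop first tries the key under self.prefix_name and only falls back to the "ernie."-prefixed form before raising, so checkpoints saved with either naming scheme remain loadable. A self-contained toy check of that key resolution, using plain dicts in place of real tensors (resolve_key is an illustrative helper, not part of the model code):

    # Toy reproduction of the fallback lookup from the hunk above.
    def resolve_key(state_dict, prefix_name, param_name):
        key = f"{prefix_name}.{param_name}"
        if key not in state_dict:
            key = f"ernie.{prefix_name}.{param_name}"
            if key not in state_dict:
                raise ValueError(f"The key {key} does not exist in state_dict. ")
        return key

    new_style = {"resampler_model.mlp.weight": 1}
    old_style = {"ernie.resampler_model.mlp.weight": 2}
    assert resolve_key(new_style, "resampler_model", "mlp.weight") == "resampler_model.mlp.weight"
    assert resolve_key(old_style, "resampler_model", "mlp.weight") == "ernie.resampler_model.mlp.weight"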