[vl]remove duplicated load logic (#2744)

bukejiyu authored 2025-07-13 07:36:26 +08:00, committed by GitHub
parent 16940822a7
commit bad53c6b6e
11 changed files with 510 additions and 632 deletions


@@ -28,9 +28,7 @@ from typing import NamedTuple, Optional
 import numpy as np
 import paddle
 import paddle.distributed as dist
-from paddle.common_ops_import import convert_dtype
-from paddle.distributed import fleet
 from paddleformers.transformers.model_utils import _add_variant
 from paddleformers.transformers.utils import paddleformers_load
 from paddleformers.utils.env import (PADDLE_WEIGHTS_INDEX_NAME,
@@ -50,7 +48,8 @@ class LayerIdPlaceholder(str, enum.Enum):
     FFN_LAYER_ID = "ffn_layer_id"
     MOE_LAYER_ID = "moe_layer_id"
     EXPERT_ID = "export_id"
+    TEXT_EXPERT_ID = "text_export_id"
+    IMG_EXPERT_ID = "img_export_id"
-
 
 class WeightMeta(NamedTuple):
     """
@@ -272,31 +271,6 @@ def load_prefix_weights(
     return past_key_values
 
-def init_distributed_env() -> tuple[int, int]:
-    """init distributed envs, and only support mp in ErnieBotModel
-
-    Returns:
-        tuple[int, int]: tensor_parallel_degree, tensor_parallel_rank
-    """
-    tensor_parallel_degree = dist.get_world_size()
-    tensor_parallel_rank = 0
-    if tensor_parallel_degree > 1:
-        strategy = fleet.DistributedStrategy()
-        strategy.hybrid_configs = {
-            "dp_degree": 1,
-            "mp_degree": tensor_parallel_degree,
-            "pp_degree": 1,
-            "sharding_degree": 1,
-        }
-        fleet.init(is_collective=True, strategy=strategy)
-        hcg = fleet.get_hybrid_communicate_group()
-        tensor_parallel_rank = hcg.get_model_parallel_rank()
-    return tensor_parallel_degree, tensor_parallel_rank
-
 
 def w4a8_weight_convert(state_dict):
     """W4A8 weight conversion function
 
     Args:
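
The deleted init_distributed_env helper is the duplicated load logic named in the commit title: it bootstrapped a fleet hybrid-parallel group with mp_degree equal to the world size (pure tensor parallelism, with dp/pp/sharding degrees pinned to 1) and returned the tensor-parallel degree and rank. A minimal sketch of what a loader did with that pair; the stand-in values and the weight slicing are illustrative assumptions, not code from this diff:

import numpy as np

# Stand-ins for init_distributed_env()'s return values: in the removed helper,
# tp_degree came from dist.get_world_size() and tp_rank from the fleet hybrid
# communicate group's model-parallel rank.
tp_degree, tp_rank = 4, 1

# Illustrative column-parallel sharding: each rank keeps only its slice of a
# full weight, so a 4-way split of the last axis loads 1/4 of the tensor.
full = np.zeros((1024, 4096), dtype="float32")       # stand-in weight
shard = np.split(full, tp_degree, axis=-1)[tp_rank]  # this rank's slice
print(shard.shape)  # (1024, 1024)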