[v1loader]Reduce EB300B model loading time (#3700)

* speed up eb45

* update
This commit is contained in:
bukejiyu
2025-09-02 19:13:57 +08:00
committed by GitHub
parent 693c7d781c
commit b6a4115369
4 changed files with 45 additions and 36 deletions

View File

@@ -415,6 +415,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
output_dim = getattr(param, "output_dim", None)
assert output_dim is not None
@@ -446,7 +447,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()
param_shard_size = output_size // 2
@@ -548,6 +549,7 @@ class QKVParallelLinear(ColumnParallelLinear):
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
if loaded_shard_id is None:
# Loaded weight is already fused on disk
@@ -568,12 +570,13 @@ class QKVParallelLinear(ColumnParallelLinear):
# Tensor parallelism splits the weight along the output_dim
if self.nranks != 1:
block_size = self._get_shard_size_mapping(loaded_shard_id)
dim = -1 if output_dim else 0
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
shard_offset = shard_id * block_size
shard_size = (shard_id + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
loaded_weight = get_tensor(loaded_weight)
if not param._is_initialized():
param.initialize()

View File

@@ -176,17 +176,24 @@ class FusedMoE(nn.Layer):
if shard_id is None:
# 1.gate up fused in disk
model_format = getattr(param, "model_format", "")
is_torch_model = model_format == "torch"
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("gate", 0, output_size // 2 * self.tp_size),
("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = slice_fn(
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
per_rank = output_size // 2
start = self.tp_rank * per_rank
loaded_weight_shard_gate = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
)
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
loaded_weight_shard_up = slice_fn(
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
)
self._load_gate_up_weight(
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
)
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
else:
# 2.gate up splited in disk
assert shard_id in ["gate", "down", "up"]
@@ -198,22 +205,23 @@ class FusedMoE(nn.Layer):
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
)
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = loaded_weight.transpose([1, 0])
dim = -1 if shard_dim else 0
if self.tp_size > 1:
is_torch_model = model_format == "torch"
if self.tp_size > 1 and not is_sharded:
tp_shard_dim = is_torch_model ^ shard_dim
weight_dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
size = loaded_weight.shape[dim]
size = loaded_weight.shape[weight_dim]
else:
size = loaded_weight.get_shape()[dim]
size = loaded_weight.get_shape()[weight_dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
dim = -1 if shard_dim else 0
param_shard_size = expert_param.shape[dim] // 2
if shard_id == "gate":
param_shard_offset = 0
@@ -232,7 +240,6 @@ class FusedMoE(nn.Layer):
)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if current_platform.is_xpu() or current_platform.is_gcu():
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (
@@ -242,24 +249,24 @@ class FusedMoE(nn.Layer):
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = loaded_weight.transpose([1, 0])
is_torch_model = model_format == "torch"
if self.tp_size > 1 and shard_dim is not None:
dim = -1 if shard_dim else 0
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
tp_shard_dim = is_torch_model ^ shard_dim
dim = -1 if tp_shard_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
if hasattr(param, "tensor_track"):
# for dyn quant
param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if current_platform.is_xpu or current_platform.is_gcu():
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU and opensource weight
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (

View File

@@ -29,7 +29,6 @@ from safetensors import safe_open
from tqdm import tqdm
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.models.tp_utils import (
check_tensor_parallel_prerequisites,
)
@@ -186,8 +185,7 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
with fast_safe_open(st_file, framework="np") as f:
for name in f.keys():
param_slice = f.get_slice(name)
paddle_tensor = get_tensor(param_slice)
yield name, paddle_tensor
yield name, param_slice
def fastsafetensors_weights_iterator(

View File

@@ -160,6 +160,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
output_dim = getattr(param, "output_dim", None)
model_format = getattr(param, "model_format", "")
if model_format == "torch":
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1: