Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[V1 Loader] Support Ernie text(moe and dense) (#3110)
* new loader support 0.3B
* fix weight
* support parallel load
* support parallel load
* fix slice
* support moe
* delete code
* perfect code
* perfect code
@@ -16,6 +16,7 @@
from typing import Optional

import numpy as np
import paddle
from paddle import nn
@@ -392,6 +393,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        if loaded_shard_id is None:
            # Loaded weight is already fused on disk.
            if self.nranks != 1:
                shard_offsets = [
                    # (shard_id, shard_offset, shard_size)
                    ("gate", 0, self.output_size * self.nranks // 2),
                    ("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2),
                ]
                for shard_id, shard_offset, shard_size in shard_offsets:
                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
                    self.weight_loader(param, loaded_weight_shard, shard_id)
            else:
                loaded_weight = get_tensor(loaded_weight)
                param.copy_(loaded_weight, False)
        else:
            # 1.fused gate_up in disk
            # 2.split gate up
            assert loaded_shard_id in ["gate", "up"]
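Note: the fused-on-disk branch above re-invokes weight_loader once per shard. A minimal sketch of the splitting arithmetic, assuming illustrative sizes that are not taken from this diff:

    import numpy as np

    # Illustrative sketch, not part of the diff above.
    # Hypothetical fused gate/up weight of shape [hidden, 2 * intermediate].
    hidden, intermediate = 4, 8
    fused = np.arange(hidden * 2 * intermediate).reshape(hidden, 2 * intermediate)

    # Mirrors the shard_offsets table: gate is the first half of the last dim, up the second half.
    shard_offsets = [("gate", 0, intermediate), ("up", intermediate, intermediate)]
    for shard_id, offset, size in shard_offsets:
        shard = fused[..., offset : offset + size]
        print(shard_id, shard.shape)  # (4, 8) for each shard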
@@ -399,6 +415,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            # Tensor parallelism splits the weight along the output_dim
            if output_dim is not None:
                dim = -1
                if isinstance(loaded_weight, np.ndarray):
                    size = loaded_weight.shape[dim]
                else:
                    size = loaded_weight.get_shape()[dim]
                block_size = size // self.nranks
                shard_offset = self.local_rank * block_size
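The output_dim branch above slices each rank's column block out of the full weight. A rough illustration of that rank-local slicing (nranks and local_rank here are made-up values, not from the diff):

    import numpy as np

    # Illustrative sketch, not part of the diff above.
    full = np.arange(6 * 12).reshape(6, 12)  # pretend full [in, out] weight
    nranks, local_rank = 4, 2                # assumed parallel degree and rank

    size = full.shape[-1]
    block_size = size // nranks              # 3 columns per rank
    shard_offset = local_rank * block_size   # rank 2 starts at column 6
    shard = full[..., shard_offset : shard_offset + block_size]
    print(shard.shape)                       # (6, 3)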
@@ -486,6 +505,23 @@ class QKVParallelLinear(ColumnParallelLinear):
        )

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        if loaded_shard_id is None:
            # Loaded weight is already fused on disk
            if self.nranks != 1:
                shard_offsets = [
                    # (shard_id, shard_offset, shard_size)
                    ("q", 0, self.num_heads * self.head_dim),
                    ("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
                    ("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
                ]
                for shard_id, shard_offset, shard_size in shard_offsets:
                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
                    self.weight_loader(param, loaded_weight_shard, shard_id)
            else:
                loaded_weight = get_tensor(loaded_weight)
                split_loaded_weight = loaded_weight
                param.copy_(split_loaded_weight, False)
        else:
            # 1.fused qkv in disk
            # 2.split q k v
            assert loaded_shard_id in ["q", "k", "v"]
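For the fused QKV path the shard offsets come straight from the head geometry. A small worked example, with num_heads, kv_num_heads and head_dim chosen arbitrarily for illustration:

    # Illustrative sketch, not part of the diff above.
    # Assumed geometry: 8 query heads, 2 KV heads, head_dim 16.
    num_heads, kv_num_heads, head_dim = 8, 2, 16

    shard_offsets = [
        ("q", 0, num_heads * head_dim),                                         # columns [0, 128)
        ("k", num_heads * head_dim, kv_num_heads * head_dim),                   # columns [128, 160)
        ("v", (num_heads + kv_num_heads) * head_dim, kv_num_heads * head_dim),  # columns [160, 192)
    ]
    for shard_id, offset, size in shard_offsets:
        print(shard_id, offset, offset + size)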
@@ -493,6 +529,9 @@ class QKVParallelLinear(ColumnParallelLinear):
            # Tensor parallelism splits the weight along the output_dim
            if output_dim is not None:
                dim = -1
                if isinstance(loaded_weight, np.ndarray):
                    size = loaded_weight.shape[dim]
                else:
                    size = loaded_weight.get_shape()[dim]
                block_size = size // self.nranks
                shard_offset = self.local_rank * block_size
@@ -203,3 +203,10 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):

        set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
        set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)

        if layer.moe_use_gate_correction_bias:
            gate_correction_bias_shape = [1, layer.num_experts]
            layer.gate_correction_bias = layer.create_parameter(
                shape=gate_correction_bias_shape,
                dtype="float32",
            )
@@ -16,6 +16,7 @@
from typing import Optional

import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
@@ -110,13 +111,18 @@ class FusedMoE(nn.Layer):
        self.weight_key_map = weight_key_map

        self.use_method = envs.FD_MOE_BACKEND.lower()
        self.gate_correction_bias = None
        self.moe_tag = moe_tag
        if self.ep_size > 1:
            expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts

        self.expert_id_offset = expert_id_offset

        self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
        if self.gate_correction_bias_key is not None:
            self.moe_use_gate_correction_bias = True
        else:
            self.moe_use_gate_correction_bias = False

        # used for deepseek_v3
        self.topk_method = topk_method
        self.topk_group = topk_group
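The expert_id_offset computed above is just the expert-parallel rank times the local expert count. A quick worked example with illustrative numbers:

    # Illustrative sketch, not part of the diff above.
    # 64 experts split over ep_size = 4 ranks gives 16 local experts per rank.
    num_experts, ep_size = 64, 4
    num_local_experts = num_experts // ep_size  # 16
    for ep_rank in range(ep_size):
        expert_id_offset = ep_rank * num_local_experts
        print(ep_rank, expert_id_offset)  # 0 -> 0, 1 -> 16, 2 -> 32, 3 -> 48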
@@ -175,16 +181,29 @@ class FusedMoE(nn.Layer):

        if shard_id is None:
            # 1.gate up fused in disk
            return
            if self.tp_size > 1:
                shard_offsets = [
                    # (shard_id, shard_offset, shard_size)
                    ("gate", 0, self.moe_intermediate_size * self.tp_size),
                    ("up", self.moe_intermediate_size * self.tp_size, self.moe_intermediate_size * self.tp_size),
                ]
                for shard_id, shard_offset, shard_size in shard_offsets:
                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
                    self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
            else:
                expert_param = param[expert_id]
                loaded_weight = get_tensor(loaded_weight)
                expert_param.copy_(loaded_weight, False)
        else:
            # 2.gate up splited in disk
            assert shard_id in ["gate", "down", "up"]
            expert_param = param[expert_id]
            if current_platform.is_cuda():
                SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
            else:
                SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
            self._load_expert_weight(
                expert_param=expert_param,
                param=param,
                expert_id=expert_id,
                shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
                loaded_weight=loaded_weight,
                shard_id=shard_id,
@@ -198,6 +217,9 @@ class FusedMoE(nn.Layer):
        expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]

        if self.tp_size > 1:
            if isinstance(loaded_weight, np.ndarray):
                size = loaded_weight.shape[-1]
            else:
                size = loaded_weight.get_shape()[-1]
            block_size = size // self.tp_size
            shard_offset = self.tp_rank * block_size
@@ -215,6 +237,9 @@ class FusedMoE(nn.Layer):

    def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
        if self.tp_size > 1:
            if isinstance(loaded_weight, np.ndarray):
                size = loaded_weight.shape[shard_dim]
            else:
                size = loaded_weight.get_shape()[shard_dim]
            block_size = size // self.tp_size
            shard_offset = self.tp_rank * block_size
@@ -231,11 +256,13 @@ class FusedMoE(nn.Layer):

    def _load_expert_weight(
        self,
        expert_param,
        param,
        expert_id,
        shard_dim,
        loaded_weight,
        shard_id,
    ):
        expert_param = param[expert_id]
        if shard_id == "down":
            self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
        elif shard_id in ["gate", "up"]:
@@ -244,29 +271,32 @@ class FusedMoE(nn.Layer):
    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        param_gate_up_proj_name: str,
        param_down_proj_name: str,
        num_experts: int,
        ckpt_expert_key_name: str = "experts",
        ckpt_gate_proj_name: Optional[str] = None,
        ckpt_up_proj_name: Optional[str] = None,
        ckpt_down_proj_name: Optional[str] = None,
        ckpt_gate_up_proj_name: Optional[str] = None,
        param_gate_up_proj_name: Optional[str] = None,
        param_down_proj_name: Optional[str] = None,
        ckpt_expert_key_name: str = "experts",
    ) -> list[tuple[str, str, int, str]]:
        param_name_maping = [
            ("gate", ckpt_gate_proj_name),
            ("down", ckpt_down_proj_name),
            ("up", ckpt_up_proj_name),
        ]
        param_name_maping = []

        if ckpt_gate_up_proj_name:
            param_name_maping.append((None, ckpt_gate_up_proj_name))
        if ckpt_gate_proj_name:
            param_name_maping.append(("gate", ckpt_gate_proj_name))
        if ckpt_down_proj_name:
            param_name_maping.append(("down", ckpt_down_proj_name))
        if ckpt_up_proj_name:
            param_name_maping.append(("up", ckpt_up_proj_name))

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    param_gate_up_proj_name
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name, ckpt_gate_up_proj_name]
                    else param_down_proj_name
                ),
                f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
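A hedged sketch of the kind of tuples the reworked helper is expected to produce for a checkpoint whose experts store a fused up_gate_proj weight (argument names mirror the Ernie call site later in this diff; the exact ordering and trailing formatting are assumptions, not verified output):

    # Illustrative sketch, not part of the diff above.
    # Roughly what make_expert_params_mapping(num_experts=2, ckpt_down_proj_name="down_proj",
    # ckpt_gate_up_proj_name="up_gate_proj", param_gate_up_proj_name="experts.up_gate_proj_",
    # param_down_proj_name="experts.down_proj_") might yield:
    expected = [
        # (param_name, weight_name, expert_id, shard_id)
        ("experts.up_gate_proj_", "experts.0.up_gate_proj.", 0, None),
        ("experts.down_proj_", "experts.0.down_proj.", 0, "down"),
        ("experts.up_gate_proj_", "experts.1.up_gate_proj.", 1, None),
        ("experts.down_proj_", "experts.1.down_proj.", 1, "down"),
    ]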
@@ -505,11 +535,6 @@ class FusedMoE(nn.Layer):
        load_state_dict function.
        """
        if not is_rearrange:
            self.gate_correction_bias_key = self.weight_key_map.get("gate_correction_bias_key", None)
            if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
                self.moe_use_gate_correction_bias = True
            else:
                self.moe_use_gate_correction_bias = False
            if self.moe_use_gate_correction_bias:
                gate_correction_bias_tensor = self.extract_gate_correction_bias(
                    self.gate_correction_bias_key, state_dict
@@ -647,12 +647,12 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
        ]
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            num_experts=self.fd_config.model_config.n_routed_experts,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            param_gate_up_proj_name="experts.up_gate_proj_",
            param_down_proj_name="experts.down_proj_",
            num_experts=self.fd_config.model_config.n_routed_experts,
        )
        params_dict = dict(self.named_parameters())
@@ -16,6 +16,7 @@
from __future__ import annotations

import inspect
from functools import partial
from typing import Dict, Union
@@ -431,6 +432,67 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
        else:
            self.lm_head.load_state_dict(state_dict)

    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """
        Load model parameters from a given weights_iterator object.

        Args:
            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
        """

        from fastdeploy.model_executor.models.utils import default_weight_loader

        general_params_mapping = [
            # (param_name, weight_name, expert_id, shard_id)
            ("embed_tokens.embeddings", "embed_tokens", None, None),
            ("lm_head.linear", "lm_head", None, None),
        ]

        expert_params_mapping = []
        if getattr(self.fd_config.model_config, "moe_num_experts", None) is not None:
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                num_experts=self.fd_config.model_config.moe_num_experts,
                ckpt_down_proj_name="down_proj",
                ckpt_gate_up_proj_name="up_gate_proj",
                param_gate_up_proj_name="experts.up_gate_proj_",
                param_down_proj_name="experts.down_proj_",
            )
            expert_params_mapping.append(
                ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias")
            )
        logger.info(f"expert params mapping:{expert_params_mapping}")
        all_param_mapping = general_params_mapping + expert_params_mapping

        params_dict = dict(self.named_parameters())
        expert_id = None
        shard_id = None

        for loaded_weight_name, loaded_weight in weights_iterator:
            for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                if weight_name not in loaded_weight_name:
                    continue
                model_param_name = loaded_weight_name.replace(weight_name, param_name)
                param = params_dict[model_param_name]
                expert_id = exp_id
                shard_id = shard_id
                break
            else:
                if loaded_weight_name not in params_dict.keys():
                    continue
                param = params_dict[loaded_weight_name]

            # Get weight loader from parameter and set weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
            sig = inspect.signature(weight_loader)
            if "expert_id" in sig.parameters:
                weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            else:
                weight_loader(param, loaded_weight)

        if self.tie_word_embeddings:
            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

    def compute_logits(self, hidden_states: paddle.Tensor):
        logits = self.lm_head(hidden_states)
        logits = paddle.cast(logits, paddle.float32)
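load_weights only assumes that weights_iterator yields (name, tensor) pairs, so a minimal driver can be sketched as follows (the checkpoint names and zero tensors below are hypothetical, not the project's actual loader or real parameter names):

    import numpy as np

    # Illustrative sketch, not part of the diff above.
    def fake_weights_iterator():
        # Hypothetical checkpoint entries; real names come from the checkpoint files.
        yield "ernie.embed_tokens.weight", np.zeros((8, 4), dtype="float32")
        yield "lm_head.weight", np.zeros((4, 8), dtype="float32")

    # Something along the lines of:
    #   model.load_weights(fake_weights_iterator())
    # would then route each tensor through the matching weight_loader via all_param_mapping.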
@@ -317,12 +317,12 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
    ) -> list[tuple[str, str, int, str]]:
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            num_experts=self.fd_config.model_config.num_experts,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            param_gate_up_proj_name="experts.up_gate_proj_",
            param_down_proj_name="experts.down_proj_",
            num_experts=self.fd_config.model_config.num_experts,
        )

    @paddle.no_grad()