mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code * Add new code files * Add new code files * update code * Try to fix build.sh * Try to fix build.sh * Update code * Update requirements.txt * Update code --------- Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
@@ -37,291 +37,13 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||
from fastdeploy.model_executor.models.utils import \
|
||||
LayerIdPlaceholder as layerid
|
||||
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
|
||||
"""
|
||||
get_tensor_parallel_mappings
|
||||
"""
|
||||
logger.info("erine inference model _get_tensor_parallel_mappings")
|
||||
|
||||
from paddleformers.transformers.conversion_utils import \
|
||||
split_or_merge_func
|
||||
|
||||
fn = split_or_merge_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
)
|
||||
|
||||
def gqa_qkv_split_func(
|
||||
weight,
|
||||
tensor_parallel_degree,
|
||||
tensor_parallel_rank,
|
||||
num_attention_heads,
|
||||
num_key_value_heads,
|
||||
head_dim,
|
||||
):
|
||||
|
||||
def get_shape(tensor):
|
||||
return (tensor.get_shape()
|
||||
if hasattr(tensor, "get_shape") else tensor.shape)
|
||||
|
||||
def slice_tensor(tensor, start, end):
|
||||
shape = get_shape(tensor)
|
||||
if len(shape) == 1:
|
||||
return tensor[start:end]
|
||||
else:
|
||||
return tensor[..., start:end]
|
||||
|
||||
q_end = num_attention_heads * head_dim
|
||||
k_end = q_end + num_key_value_heads * head_dim
|
||||
v_end = k_end + num_key_value_heads * head_dim
|
||||
|
||||
q = slice_tensor(weight, 0, q_end)
|
||||
k = slice_tensor(weight, q_end, k_end)
|
||||
v = slice_tensor(weight, k_end, v_end)
|
||||
|
||||
def split_tensor(tensor, degree):
|
||||
shape = get_shape(tensor)
|
||||
size = shape[-1]
|
||||
block_size = size // degree
|
||||
if hasattr(tensor, "get_shape"):
|
||||
return [
|
||||
slice_tensor(tensor, i * block_size,
|
||||
(i + 1) * block_size)
|
||||
for i in range(degree)
|
||||
]
|
||||
else:
|
||||
return np.split(tensor, degree, axis=-1)
|
||||
|
||||
q_list = split_tensor(q, tensor_parallel_degree)
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
|
||||
if tensor_parallel_rank is None:
|
||||
return [
|
||||
np.concatenate([q_i, k_i, v_i], axis=-1)
|
||||
for q_i, k_i, v_i in zip(q_list, k_list, v_list)
|
||||
]
|
||||
else:
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
def gqa_qkv_merge_func(weight_list, num_attention_heads,
|
||||
num_key_value_heads, head_dim):
|
||||
tensor_parallel_degree = len(weight_list)
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
||||
|
||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||
|
||||
def get_shape(tensor):
|
||||
return (tensor.get_shape()
|
||||
if hasattr(tensor, "get_shape") else tensor.shape)
|
||||
|
||||
def slice_tensor(tensor, start, end):
|
||||
if len(get_shape(tensor)) == 1:
|
||||
return tensor[start:end]
|
||||
else:
|
||||
return tensor[..., start:end]
|
||||
|
||||
q_list, k_list, v_list = [], [], []
|
||||
|
||||
for weight in weight_list:
|
||||
q_end = num_attention_heads * head_dim
|
||||
k_end = q_end + num_key_value_heads * head_dim
|
||||
v_end = k_end + num_key_value_heads * head_dim
|
||||
|
||||
q = slice_tensor(weight, 0, q_end)
|
||||
k = slice_tensor(weight, q_end, k_end)
|
||||
v = slice_tensor(weight, k_end, v_end)
|
||||
|
||||
q_list.append(q)
|
||||
k_list.append(k)
|
||||
v_list.append(v)
|
||||
|
||||
merged = q_list + k_list + v_list
|
||||
|
||||
if is_paddle_tensor:
|
||||
tensor = paddle.concat(merged, axis=-1)
|
||||
if tensor.place.is_gpu_place():
|
||||
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
|
||||
return tensor
|
||||
else:
|
||||
return np.concatenate(merged, axis=-1)
|
||||
|
||||
if (config.num_key_value_heads is not None
|
||||
and config.num_key_value_heads != config.num_attention_heads):
|
||||
if is_split:
|
||||
qkv_fn = partial(
|
||||
gqa_qkv_split_func,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
else:
|
||||
qkv_fn = partial(
|
||||
gqa_qkv_merge_func,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
)
|
||||
else:
|
||||
qkv_fn = partial(fn, is_column=True)
|
||||
|
||||
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
|
||||
moe_num_shared_experts,
|
||||
moe_layer_start_index):
|
||||
|
||||
final_actions = {}
|
||||
|
||||
base_model_prefix = "ernie"
|
||||
base_actions = {
|
||||
"lm_head.weight":
|
||||
partial(fn, is_column=True),
|
||||
# "eh_proj.weight": partial(fn, is_column=True),
|
||||
f"{base_model_prefix}.embed_tokens.weight":
|
||||
partial(fn, is_column=False),
|
||||
}
|
||||
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.weight"] = qkv_fn
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.qkv_proj.quant_weight"] = qkv_fn
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.o_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.self_attn.o_proj.quant_weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.down_proj.weight"] = (
|
||||
partial(fn, is_column=False))
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
for expert_idx in range(moe_num_experts):
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.up_gate_proj.weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.experts.{expert_idx}.down_proj.quant_weight"] = partial(
|
||||
fn, is_column=False)
|
||||
|
||||
if moe_num_shared_experts > 0:
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.up_gate_proj.weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=True, is_naive_2fuse=True)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.down_proj.weight"] = partial(
|
||||
fn, is_column=False)
|
||||
base_actions[
|
||||
f"{base_model_prefix}.layers.{moe_layer_start_index}"
|
||||
f".mlp.shared_experts.up_gate_proj.quant_weight"] = partial(
|
||||
fn, is_column=False, is_naive_2fuse=True)
|
||||
|
||||
for key, action in base_actions.items():
|
||||
if (f"{base_model_prefix}.layers.0.mlp.up_gate_proj.weight"
|
||||
in key or
|
||||
f"{base_model_prefix}.layers.0.mlp.up_gate_proj.quant_weight"
|
||||
in key
|
||||
or f"{base_model_prefix}.layers.0.mlp.down_proj.weight"
|
||||
in key or
|
||||
f"{base_model_prefix}.layers.0.mlp.down_proj.quant_weight"
|
||||
in key):
|
||||
for i in range(moe_layer_start_index):
|
||||
final_actions[key.replace("layers.0.",
|
||||
f"layers.{i}.")] = action
|
||||
elif f"layers.{moe_layer_start_index}.mlp.experts." in key:
|
||||
for i in range(moe_layer_start_index, num_layers):
|
||||
final_actions[key.replace(
|
||||
f"layers.{moe_layer_start_index}.",
|
||||
f"layers.{i}.")] = action
|
||||
elif f"layers.{moe_layer_start_index}.mlp.shared_experts." in key:
|
||||
for i in range(moe_layer_start_index, num_layers):
|
||||
final_actions[key.replace(
|
||||
f"layers.{moe_layer_start_index}.",
|
||||
f"layers.{i}.")] = action
|
||||
elif f"{base_model_prefix}.layers.0." in key:
|
||||
for i in range(num_layers):
|
||||
final_actions[key.replace("layers.0.",
|
||||
f"layers.{i}.")] = action
|
||||
final_actions[key] = action
|
||||
return final_actions
|
||||
|
||||
moe_num_experts = 0
|
||||
moe_num_shared_experts = 0
|
||||
if isinstance(config.moe_num_experts, list):
|
||||
moe_num_experts = sum(config.moe_num_experts)
|
||||
elif isinstance(config.moe_num_experts, int):
|
||||
moe_num_experts = config.moe_num_experts
|
||||
if hasattr(config, 'moe_num_shared_experts'):
|
||||
moe_num_shared_experts = config.moe_num_shared_experts
|
||||
|
||||
moe_layer_start_index = -1
|
||||
if isinstance(config.moe_layer_start_index, list):
|
||||
moe_layer_start_index = min(config.moe_layer_start_index)
|
||||
elif isinstance(config.moe_layer_start_index, int):
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(
|
||||
config.num_layers,
|
||||
moe_num_experts,
|
||||
moe_num_shared_experts,
|
||||
moe_layer_start_index,
|
||||
)
|
||||
|
||||
return mappings
|
||||
|
||||
|
||||
class Ernie4_5_MLP(nn.Layer):
|
||||
|
||||
def __init__(
|
||||
@@ -329,6 +51,7 @@ class Ernie4_5_MLP(nn.Layer):
|
||||
fd_config: FDConfig,
|
||||
intermediate_size: int,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
@@ -345,7 +68,7 @@ class Ernie4_5_MLP(nn.Layer):
|
||||
self.down_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
input_size=(intermediate_size // self.nranks),
|
||||
input_size=intermediate_size,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
with_bias=False,
|
||||
)
|
||||
@@ -423,8 +146,8 @@ class Ernie4_5_MoE(nn.Layer):
|
||||
f"{prefix}.experts.{{}}.down_proj.code_zp",
|
||||
}
|
||||
elif moe_quant_type == "tensor_wise_fp8" or (
|
||||
moe_quant_type == "block_wise_fp8" and
|
||||
fd_config.model_config.is_quantized):
|
||||
moe_quant_type == "block_wise_fp8"
|
||||
and fd_config.model_config.is_quantized):
|
||||
weight_key_map = {
|
||||
"gate_weight_key":
|
||||
f"{prefix}.gate.weight",
|
||||
@@ -492,8 +215,6 @@ class Ernie4_5_Attention(nn.Layer):
|
||||
prefix: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
@@ -502,8 +223,8 @@ class Ernie4_5_Attention(nn.Layer):
|
||||
self.o_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=(fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads // nranks),
|
||||
input_size=fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
self.attn = Attention(
|
||||
@@ -636,12 +357,12 @@ class Ernie4_5_Model(nn.Layer):
|
||||
params_dtype=paddle.get_default_dtype(),
|
||||
prefix=(f"{fd_config.model_config.prefix_name}.embed_tokens"))
|
||||
|
||||
self.hidden_layers = [
|
||||
self.hidden_layers = nn.LayerList([
|
||||
Ernie4_5_DecoderLayer(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
])
|
||||
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
@@ -772,3 +493,134 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
|
||||
Model Architecture Name
|
||||
"""
|
||||
return "Ernie4_5_ForCausalLM"
|
||||
|
||||
|
||||
class Ernie4_5_PretrainedModel(PretrainedModel):
|
||||
"""
|
||||
Ernie4_5_PretrainedModel
|
||||
"""
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
weight_infos = [
|
||||
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
|
||||
True, tsm.GQA),
|
||||
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
|
||||
False),
|
||||
WeightMeta(".embed_tokens.weight", False),
|
||||
WeightMeta("lm_head.weight", True),
|
||||
# quant tensorwise
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight",
|
||||
True, tsm.GQA),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight",
|
||||
False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight",
|
||||
True, tsm.PairFused),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight",
|
||||
False),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _get_tensor_parallel_mappings(cls, config: ModelConfig, is_split=True):
|
||||
"""
|
||||
get_tensor_parallel_mappings
|
||||
"""
|
||||
logger.info("erine inference model _get_tensor_parallel_mappings")
|
||||
from fastdeploy.model_executor.models.tp_utils import (
|
||||
build_expanded_keys, has_prefix, split_or_merge_func_v1)
|
||||
|
||||
fn = split_or_merge_func_v1(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim)
|
||||
|
||||
def get_tensor_parallel_split_mappings(num_layers, moe_num_experts,
|
||||
moe_layer_start_index,
|
||||
prefix_name):
|
||||
base_actions = {}
|
||||
weight_infos = cls.weight_infos
|
||||
for (weight_name, is_column, extra) in weight_infos:
|
||||
params = {
|
||||
"is_column": is_column,
|
||||
**({
|
||||
extra.value: True
|
||||
} if extra else {})
|
||||
}
|
||||
|
||||
if "lm_head.weight" in weight_name:
|
||||
key = weight_name
|
||||
elif not has_prefix(prefix_name, weight_name):
|
||||
key = f"{prefix_name}{weight_name}"
|
||||
else:
|
||||
key = weight_name
|
||||
base_actions[key] = partial(fn, **params)
|
||||
final_actions = {}
|
||||
start_layer = (moe_layer_start_index
|
||||
if moe_layer_start_index > 0 else num_layers)
|
||||
final_actions = build_expanded_keys(
|
||||
num_layers,
|
||||
moe_num_experts,
|
||||
start_layer,
|
||||
base_actions,
|
||||
)
|
||||
return final_actions
|
||||
|
||||
moe_num_experts = 0
|
||||
if isinstance(config.moe_num_experts, list):
|
||||
moe_num_experts = sum(config.moe_num_experts)
|
||||
elif isinstance(config.moe_num_experts, int):
|
||||
moe_num_experts = config.moe_num_experts
|
||||
|
||||
moe_layer_start_index = -1
|
||||
if isinstance(config.moe_layer_start_index, list):
|
||||
moe_layer_start_index = min(config.moe_layer_start_index)
|
||||
elif isinstance(config.moe_layer_start_index, int):
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
|
||||
mappings = get_tensor_parallel_split_mappings(config.num_layers,
|
||||
moe_num_experts,
|
||||
moe_layer_start_index,
|
||||
config.prefix_name)
|
||||
return mappings
|
||||
|
||||
Reference in New Issue
Block a user