Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[V1 Loader] support weight_only (#3413)
Some checks failed: all CE Compile Job, Deploy GitHub Pages, and Publish Job workflow runs for this push were cancelled.
* support wint4/wint8
* delete smoe case
* update ci
* print log
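For context, wint4/wint8 refer to weight-only quantization: linear-layer weights are stored as int4/int8 with one floating-point scale per output channel and dequantized on the fly at matmul time, while activations stay in the original dtype. The following is a self-contained NumPy sketch of the wint8 idea, illustrative only and not FastDeploy's fused kernel:

import numpy as np

def quantize_wint8(w):
    """Symmetric per-output-channel int8 quantization of a [in_dim, out_dim] weight."""
    scale = np.abs(w).max(axis=0) / 127.0              # one fp32 scale per output channel
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def weight_only_matmul(x, q, scale):
    """Activations stay fp32; the int8 weight is dequantized on the fly."""
    return x @ (q.astype(np.float32) * scale)

w = np.random.randn(128, 256).astype(np.float32)
x = np.random.randn(4, 128).astype(np.float32)
q, scale = quantize_wint8(w)
print("max abs error vs fp32 matmul:", np.abs(weight_only_matmul(x, q, scale) - x @ w).max())

wint4 works the same way with a [-7, 7] quantization range and two values packed per byte.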
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import inspect
+import re
 from functools import partial
 from typing import Dict, Union
 
@@ -149,15 +150,6 @@ class Ernie4_5_MoE(nn.Layer):
             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
         }
 
-        self.experts = FusedMoE(
-            fd_config=fd_config,
-            moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
-            num_experts=fd_config.model_config.moe_num_experts,
-            top_k=fd_config.model_config.moe_k,
-            layer_idx=layer_id,
-            weight_key_map=weight_key_map,
-        )
-
         self.gate = ReplicatedLinear(
             fd_config=fd_config,
             prefix=f"{prefix}.gate",
@@ -168,6 +160,25 @@ class Ernie4_5_MoE(nn.Layer):
             weight_dtype="float32",
         )
 
+        self.experts = FusedMoE(
+            fd_config=fd_config,
+            moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
+            num_experts=fd_config.model_config.moe_num_experts,
+            top_k=fd_config.model_config.moe_k,
+            layer_idx=layer_id,
+            gate_correction_bias=None,
+            weight_key_map=weight_key_map,
+        )
+
+        if fd_config.model_config.moe_use_aux_free:
+            self.experts.gate_correction_bias = self.create_parameter(
+                shape=[1, fd_config.model_config.moe_num_experts],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+        else:
+            self.experts.gate_correction_bias = None
+
         self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
         if self.num_shared_experts > 0:
             shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -180,6 +191,13 @@ class Ernie4_5_MoE(nn.Layer):
     def load_state_dict(self, state_dict):
         self.gate.load_state_dict(state_dict)
         self.experts.load_state_dict(state_dict)
+        if self.experts.gate_correction_bias is not None:
+            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
+            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
+                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
+                    self.experts.gate_correction_bias.shape
+                )
+            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)
 
@@ -441,12 +459,16 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
 
-        from fastdeploy.model_executor.models.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )
 
         general_params_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
             ("embed_tokens.embeddings", "embed_tokens", None, None),
             ("lm_head.linear", "lm_head", None, None),
+            ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
         ]
 
         expert_params_mapping = []
@@ -458,13 +480,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             param_gate_up_proj_name="experts.up_gate_proj_",
             param_down_proj_name="experts.down_proj_",
         )
-        expert_params_mapping.append(
-            ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, "gate_bias")
-        )
-        logger.info(f"expert params mapping:{expert_params_mapping}")
         all_param_mapping = general_params_mapping + expert_params_mapping
 
         params_dict = dict(self.named_parameters())
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         expert_id = None
         shard_id = None
 
@@ -478,9 +497,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
                     shard_id = shard_id
                     break
             else:
-                if loaded_weight_name not in params_dict.keys():
+                model_param_name = loaded_weight_name
+                if model_param_name not in params_dict.keys():
                     continue
-            param = params_dict[loaded_weight_name]
+            param = params_dict[model_param_name]
 
             # Get weight loader from parameter and set weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
@@ -490,6 +510,8 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             else:
                 weight_loader(param, loaded_weight)
 
+            model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
+            process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
             self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
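The two lines added at the end of the loading loop are what ties this diff to weight-only support: after each parameter is filled, the owning sublayer's name is recovered by stripping the trailing weight suffix, and process_weights_after_loading_fn gives that sublayer a chance to repack its freshly loaded floating-point weights into the quantized wint4/wint8 layout. The helper itself lives in fastdeploy.model_executor.utils and its body is not part of this diff; a minimal sketch of the closure pattern it appears to follow, with hypothetical internals:

# Hypothetical sketch only; the real implementation is
# fastdeploy.model_executor.utils.process_weights_after_loading.
def process_weights_after_loading(sublayers: dict):
    """Return a callback that lets a named sublayer finalize its weights."""
    def fn(model_sublayer_name: str, param):
        layer = sublayers.get(model_sublayer_name)
        hook = getattr(layer, "process_weights_after_loading", None)
        if hook is not None:
            hook()  # e.g. quantize/repack the loaded fp weight for wint4/wint8
    return fn

Building the closure once from dict(self.named_sublayers()) avoids re-walking the layer tree for every weight in the iterator.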