[CP]Glm45 air 2.2 (#4073)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928)
* support glm45_air
* [Feature] GLM-45-AIR Support Mix Quantization (Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)
* check
* fix v1 load for mix and wint8
* check --quantizations 'None'
* check
* support RL rollout
* check v1 loader
* check glm rollout_model, change wfp8afp8 per_token_cast_to_fp8 to native impl
* check rollout moe gate begin layer_id
* check rollout e_score_correction_bias
* delete infer_to_train_mapping={}
* code check
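One item above replaces the wfp8afp8 per_token_cast_to_fp8 helper with a native implementation. The function name is taken from the commit message; the body below is only a minimal sketch of the usual per-token FP8 (e4m3) quantization recipe in Paddle, assuming a Paddle build that exposes the float8_e4m3fn dtype, and it is not the code this commit actually ships.

import paddle


def per_token_cast_to_fp8(x: paddle.Tensor, eps: float = 1e-6):
    """Cast activations to FP8 e4m3 with one scale per token (sketch, see note above)."""
    # Per-row (per-token) absolute maximum, clipped away from zero.
    amax = paddle.clip(paddle.max(paddle.abs(x), axis=-1, keepdim=True), min=eps)
    # 448.0 is the largest finite value representable in float8 e4m3fn.
    scale = amax / 448.0
    # Assumption: this Paddle build accepts "float8_e4m3fn" as an astype() target.
    x_fp8 = (x / scale).astype("float8_e4m3fn")
    # Return the per-token scales so the consumer (e.g. a wfp8afp8 matmul) can dequantize.
    return x_fp8, scale.astype("float32")

Per-token scaling keeps one scale per activation row, which bounds quantization error for outlier-heavy tokens better than a single per-tensor scale.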
@@ -14,6 +14,7 @@
# limitations under the License.
"""

import copy
from typing import Dict

import paddle

@@ -28,6 +29,10 @@ from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
    Ernie4_5_VLMoeForConditionalGeneration,
    Ernie4_5_VLPretrainedModel,
)
from fastdeploy.model_executor.models.glm4_moe import (
    Glm4MoeForCausalLM,
    Glm4MoePretrainedModel,
)
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import (
    Qwen2ForCausalLM,

@@ -529,3 +534,83 @@ class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, B
        self._complete_missing_mappings()

        return self.infer_to_train_mapping


class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel):
    """
    Glm4MoeForCausalLMRL
    """

    _get_tensor_parallel_mappings = Glm4MoePretrainedModel._get_tensor_parallel_mappings

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Glm4MoeForCausalLMRL, self).__init__(fd_config)

    @classmethod
    def name(cls) -> str:
        """name"""
        return "Glm4MoeForCausalLMRL"

    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
        """Generate mapping between inference and training parameters for RL (do not delete!)."""
        if self._mappings_built:
            return self.infer_to_train_mapping

        self.infer_to_train_mapping = {}
        self._mappings_built = True
        # Prepare placeholders
        place_holders = ["weight"]

        # Initialize mapping dictionary
        self._update_base_mappings("model")

        base_name = "model.layers"

        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx: int):
            # MoE specific mappings: the router gate and its correction bias map one-to-one
            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
                f"{base_name}.{layer_idx}.mlp.gate.weight"
            )

            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = (
                f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"
            )

            # MoE experts mappings: each fused inference tensor collects one entry per routed expert
            for expert_idx in range(self.fd_config.model_config.n_routed_experts):
                for ph in place_holders:
                    # up_gate_proj
                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
                    if up_gate_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[up_gate_proj_key] = []
                    self.infer_to_train_mapping[up_gate_proj_key].append(
                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                    )

                    # down_proj
                    down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
                    if down_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[down_proj_key] = []
                    self.infer_to_train_mapping[down_proj_key].append(
                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                    )

        # Process MoE layers; the first `first_k_dense_replace` layers are dense and get no expert mappings
        for layer_idx in range(
            self.fd_config.model_config.first_k_dense_replace,
            self.fd_config.model_config.num_hidden_layers,
        ):
            _add_layer_mappings(layer_idx)

        self._complete_missing_mappings()
        # Drop any fused mlp.experts.gate_correction_bias entries before returning the mapping
        infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping)
        for key in infer_to_train_mapping_copy.keys():
            if "mlp.experts.gate_correction_bias" in key:
                self.infer_to_train_mapping.pop(key)

        return self.infer_to_train_mapping
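For orientation, the mapping returned by get_name_mappings_to_training pairs router tensors one-to-one and maps each fused inference-side expert tensor to a list of per-expert training tensors. The snippet below is a hand-written illustration of that shape for a single MoE layer with two routed experts; the layer index and expert count are hypothetical, not output captured from the model.

# Illustration only: expected structure of the infer-to-train mapping for one
# MoE layer with two routed experts (hypothetical indices, not captured output).
example_mapping = {
    # Router weights map one-to-one.
    "model.layers.3.mlp.gate.weight": "model.layers.3.mlp.gate.weight",
    "model.layers.3.mlp.gate.e_score_correction_bias": "model.layers.3.mlp.gate.e_score_correction_bias",
    # Fused expert tensors on the inference side collect one training tensor per expert.
    "model.layers.3.mlp.experts.up_gate_proj_weight": [
        "model.layers.3.mlp.experts.0.up_gate_proj.weight",
        "model.layers.3.mlp.experts.1.up_gate_proj.weight",
    ],
    "model.layers.3.mlp.experts.down_proj_weight": [
        "model.layers.3.mlp.experts.0.down_proj.weight",
        "model.layers.3.mlp.experts.1.down_proj.weight",
    ],
}

The cleanup loop at the end of the diff then removes any key containing mlp.experts.gate_correction_bias from this mapping before it is returned.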