Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 17:17:14 +08:00
【Sync develop】support vl model name_mapping and ori_vocab_size (#2915)
* support vl ori_vocab_size
* support trainer_degree in name_mapping
* fix
@@ -37,6 +37,25 @@ class MoEPhase(Enum):
     PREFILL = 1
     DECODER = 2

+class ErnieArchitectures:
+    """Helper class for ERNIE architecture check."""
+
+    ARCHITECTURES = {
+        "Ernie4_5_ForCausalLM",
+        "Ernie4_5_MoeForCausalLM",
+        "Ernie4_5_VLMoeForConditionalGeneration"
+    }
+
+    @classmethod
+    def contains_ernie_arch(cls, architectures):
+        """Check if any ERNIE architecture is present in the given architectures."""
+        return any(arch in architectures for arch in cls.ARCHITECTURES)
+
+    @classmethod
+    def is_ernie_arch(cls, architecture):
+        """Check if the given architecture is an ERNIE architecture."""
+        return architecture in cls.ARCHITECTURES
+
 PRETRAINED_INIT_CONFIGURATION = {
     "rope_theta" : 10000.0,
     "num_key_value_heads" : -1,
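The new helper replaces the scattered substring checks that the later hunks delete. A quick sanity sketch of how it behaves (the class body is copied from the hunk above; the example `architectures` lists are invented):

```python
class ErnieArchitectures:
    """Helper class for ERNIE architecture check."""

    ARCHITECTURES = {
        "Ernie4_5_ForCausalLM",
        "Ernie4_5_MoeForCausalLM",
        "Ernie4_5_VLMoeForConditionalGeneration"
    }

    @classmethod
    def contains_ernie_arch(cls, architectures):
        return any(arch in architectures for arch in cls.ARCHITECTURES)

    @classmethod
    def is_ernie_arch(cls, architecture):
        return architecture in cls.ARCHITECTURES

# Invented inputs, mimicking the `architectures` list of a config.json:
assert ErnieArchitectures.contains_ernie_arch(["Ernie4_5_VLMoeForConditionalGeneration"])
assert not ErnieArchitectures.contains_ernie_arch(["Qwen2ForCausalLM"])
assert ErnieArchitectures.is_ernie_arch("Ernie4_5_ForCausalLM")
```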
@@ -108,9 +127,10 @@ class ModelConfig:
             self.vision_config = PretrainedConfig.from_dict(self.vision_config)

         self.ori_vocab_size = self.vocab_size
-        if "Ernie4_5_ForCausalLM" in self.architectures or "Ernie4_5_MoeForCausalLM" in self.architectures:
+        if ErnieArchitectures.contains_ernie_arch(self.architectures):
             self.ori_vocab_size = args["ori_vocab_size"]


 class ParallelConfig:
     """Configuration for the distributed execution."""
     def __init__(
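Only the control flow is visible in this hunk: `ori_vocab_size` defaults to `vocab_size` and is overridden from `args` for ERNIE checkpoints, whose effective vocabulary can differ from the original tokenizer's. A minimal sketch with invented values:

```python
# Invented values for illustration; only the control flow mirrors the hunk above.
args = {"ori_vocab_size": 100000}
vocab_size = 103424                      # size of the (possibly extended) output vocab
ori_vocab_size = vocab_size              # default when nothing overrides it

contains_ernie_arch = True               # ErnieArchitectures.contains_ernie_arch(...)
if contains_ernie_arch:
    ori_vocab_size = args["ori_vocab_size"]

print(ori_vocab_size)                    # 100000
```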
@@ -17,6 +17,7 @@ from typing import Any, Dict, Optional

 from fastdeploy.engine.config import ModelConfig
 from fastdeploy.reasoning import ReasoningParserManager
+from fastdeploy.config import ErnieArchitectures


 class InputPreprocessor:
@@ -71,8 +72,7 @@ class InputPreprocessor:
                 self.reasoning_parser)
         architectures = ModelConfig(self.model_name_or_path).architectures
         if not self.enable_mm:
-            if "Ernie4_5_MoeForCausalLM" not in architectures \
-                and "Ernie4_5_ForCausalLM" not in architectures:
+            if not ErnieArchitectures.contains_ernie_arch(architectures):
                 from fastdeploy.input.text_processor import DataProcessor
                 self.processor = DataProcessor(
                     model_name_or_path=self.model_name_or_path, reasoning_parser_obj=reasoning_parser_obj)
@@ -17,7 +17,7 @@
 import os
 from concurrent.futures import ThreadPoolExecutor

-from fastdeploy.config import FDConfig
+from fastdeploy.config import FDConfig, ErnieArchitectures
 from fastdeploy.engine.request import Request
 from fastdeploy.utils import llm_logger

@@ -268,8 +268,7 @@ class BackendBase:
         """
         try:
             architectures = self.fd_config.model_config.architectures
-            if "Ernie4_5_MoeForCausalLM" not in architectures \
-                and "Ernie4_5_ForCausalLM" not in architectures:
+            if not ErnieArchitectures.contains_ernie_arch(architectures):

                 from transformers import AutoTokenizer, PreTrainedTokenizerFast
                 tokenizer = AutoTokenizer.from_pretrained(
@@ -161,7 +161,7 @@ class Ernie4_5_VLMoE(nn.Layer):

         self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
         if self.num_shared_experts > 0:
-            self.share_experts = Ernie4_5_VLMLP(
+            self.shared_experts = Ernie4_5_VLMLP(
                 fd_config=fd_config,
                 intermediate_size=self.num_shared_experts *
                 fd_config.model_config.moe_intermediate_size[0],
@@ -193,11 +193,11 @@ class Ernie4_5_VLMoE(nn.Layer):
         if self.text_fused_moe.moe_use_gate_correction_bias:
             state_dict.pop(self.text_fused_moe.gate_correction_bias_key)
         if self.num_shared_experts > 0:
-            self.share_experts.load_state_dict(state_dict)
+            self.shared_experts.load_state_dict(state_dict)

     def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
         if self.num_shared_experts > 0:
-            share_experts_out = self.share_experts(hidden_states)
+            shared_experts_out = self.shared_experts(hidden_states)
         if vl_moe_meta.image_input is not None:
             text_image_gather_scatter(
                 hidden_states,
@@ -222,7 +222,7 @@ class Ernie4_5_VLMoE(nn.Layer):
         else:
             hidden_states = self.text_fused_moe(hidden_states)
         if self.num_shared_experts > 0:
-            hidden_states += share_experts_out
+            hidden_states += shared_experts_out
         if self.tp_size > 1:
             tensor_model_parallel_all_reduce(hidden_states)
         return hidden_states
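Beyond the `share_experts` to `shared_experts` rename, the control flow is the usual shared-expert MoE pattern: the shared branch runs on every token and is added back to the routed output. A toy, paddle-free sketch of that pattern (scalars stand in for tensors, callables for sublayers):

```python
def moe_forward(hidden_states, fused_moe, shared_experts=None):
    """Toy MoE forward: routed output plus an always-on shared-expert branch."""
    shared_experts_out = shared_experts(hidden_states) if shared_experts else None
    hidden_states = fused_moe(hidden_states)        # routed experts
    if shared_experts_out is not None:
        hidden_states += shared_experts_out         # mirrors `hidden_states += shared_experts_out`
    return hidden_states

print(moe_forward(2.0, lambda x: 3.0 * x, lambda x: 0.5 * x))  # 6.0 + 1.0 = 7.0
```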
@@ -23,15 +23,14 @@ from paddleformers.utils.log import logger
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.model_loader import ModelRegistry
 from fastdeploy.model_executor.models.ernie4_5_moe import \
-    Ernie4_5_MoeForCausalLM
+    Ernie4_5_MoeForCausalLM, Ernie4_5_PretrainedModel
 from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import \
-    Ernie4_5_VLMoeForConditionalGeneration
+    Ernie4_5_VLMoeForConditionalGeneration, Ernie4_5_VLPretrainedModel
-from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
+from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2PretrainedModel
-from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
+from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3PretrainedModel
-from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
+from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM, Qwen3MoePretrainedModel
 from fastdeploy.rl.rollout_config import RolloutModelConfig


 class RolloutModel(nn.Layer):
     """Main model class for rollout operations, supports multimodal components for train."""
@@ -51,9 +50,13 @@ class RolloutModel(nn.Layer):
         model.eval()
         return model

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
-        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
+        return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
+
+    def get_quantization_infer_keys(self) -> Dict[str, str]:
+        """Get quantization infer keys from the rollout model."""
+        return getattr(self.rollout_model, "get_quantization_infer_keys", lambda: {})()

     @paddle.no_grad()
     def state_dict(self):
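Both wrappers use `getattr` with a callable default so that models lacking the hook degrade to an empty dict. One caveat worth flagging: now that `trainer_degree` is passed through, the zero-argument default `lambda: {}` would itself raise `TypeError` if a rollout model ever lacked the method. A standalone sketch of the pattern (class name invented):

```python
class DummyRollout:
    """Stand-in without the optional hook."""

dummy = DummyRollout()

# With the zero-argument default from the hunk, passing trainer_degree breaks
# whenever the hook is absent:
try:
    getattr(dummy, "get_name_mappings_to_training", lambda: {})(8)
except TypeError as exc:
    print("zero-arg default rejects the argument:", exc)

# A variadic default keeps the empty-dict fallback working either way:
print(getattr(dummy, "get_name_mappings_to_training", lambda *_: {})(8))  # {}
```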
@@ -61,10 +64,51 @@ class RolloutModel(nn.Layer):
         return self.rollout_model.state_dict()


-class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
+class BaseRLModel(nn.Layer):
+    """Base class for RL models with common functionality"""
+    def __init__(self,):
+        super(BaseRLModel, self).__init__()
+        self.infer_to_train_mapping = {}
+        self.fd_config = None
+
+    @classmethod
+    def name(cls) -> str:
+        return cls.__name__
+
+    def _update_base_mappings(self, base_name: str) -> None:
+        """Common static mappings"""
+        static_mappings = {
+            f"{base_name}.embed_tokens.embeddings.weight": f"{base_name}.embed_tokens.weight",
+            "lm_head.linear.weight": "lm_head.weight"
+        }
+        self.infer_to_train_mapping.update(static_mappings)
+
+    def _complete_missing_mappings(self) -> None:
+        """
+        Complete the mapping dictionary with keys that have identical names in inference and training.
+        """
+        for key in self.state_dict().keys():
+            if key not in self.infer_to_train_mapping and "_scale" not in key:
+                # Skip weight scale parameters in mapping. Train and infer have same key.
+                self.infer_to_train_mapping[key] = key
+
+    def get_quantization_infer_keys(self) -> list[str]:
+        """Get quantization infer keys"""
+        quant_weight_key = []
+        if self.fd_config.quant_config.name() == "wint8":
+            # RL only supports weight_only_int8 for now.
+            for key in self.state_dict().keys():
+                if "scale" in key:
+                    quant_weight_key.append(key.replace(".weight_scale", ".weight"))
+        else:
+            raise ValueError("Only 'wint8' quantization is supported in RL rollout.")
+        return quant_weight_key
+
+
+class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
     """
     Ernie4_5_MoeForCausalLMRL
     """
+    _get_tensor_parallel_mappings = Ernie4_5_PretrainedModel._get_tensor_parallel_mappings

     def __init__(self, fd_config: FDConfig):
         """
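The two base-class helpers are easiest to see on toy data. A self-contained rerun of the `_complete_missing_mappings` rule (dict and key names are invented; a real model would iterate `self.state_dict()`):

```python
infer_to_train_mapping = {"m.embed_tokens.embeddings.weight": "m.embed_tokens.weight"}
state_dict_keys = [
    "m.embed_tokens.embeddings.weight",
    "m.layers.0.self_attn.qkv_proj.weight",
    "m.layers.0.self_attn.qkv_proj.weight_scale",  # quantization scale
]

# Same rule as _complete_missing_mappings: identity-map anything unmapped,
# but skip "_scale" keys, which exist only on the inference side.
for key in state_dict_keys:
    if key not in infer_to_train_mapping and "_scale" not in key:
        infer_to_train_mapping[key] = key

print(infer_to_train_mapping)
# {'m.embed_tokens.embeddings.weight': 'm.embed_tokens.weight',
#  'm.layers.0.self_attn.qkv_proj.weight': 'm.layers.0.self_attn.qkv_proj.weight'}
```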
@@ -78,37 +122,23 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]

         # Initialize mapping dictionary
-        infer_to_train = {}
-
-        base_name = "ernie"
-        # Static mappings (non-layer specific)
-        static_mappings = {
-            f"{base_name}.embed_tokens.embeddings.weight":
-            f"{base_name}.embed_tokens.weight",
-            "lm_head.linear.weight": "lm_head.weight"
-        }
-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            # Support tie_word_embeddings
-            logger.debug("enable tie_word_embeddings")
-            static_mappings.pop("lm_head.linear.weight")
-        infer_to_train.update(static_mappings)
-
-        base_name = base_name + ".layers"
-
+        self._update_base_mappings("ernie")
+
+        base_name = "ernie.layers"
         # Helper function to add layer mappings
         def _add_layer_mappings(layer_idx: int):
             # MoE specific mappings
-            infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
+            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
                 f"{base_name}.{layer_idx}.mlp.gate.weight"

             if self.fd_config.model_config.moe_use_aux_free:
-                infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
+                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
                     f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"

             # MoE experts mappings
@@ -116,17 +146,17 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
             for ph in place_holders:
                 # up_gate_proj (up_gate_proj)
                 up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.up_gate_proj_weight"
-                if up_gate_proj_key not in infer_to_train:
-                    infer_to_train[up_gate_proj_key] = []
-                infer_to_train[up_gate_proj_key].append(
+                if up_gate_proj_key not in self.infer_to_train_mapping:
+                    self.infer_to_train_mapping[up_gate_proj_key] = []
+                self.infer_to_train_mapping[up_gate_proj_key].append(
                     f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                 )

                 # down_proj (down_proj)
                 down_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.down_proj_weight"
-                if down_proj_key not in infer_to_train:
-                    infer_to_train[down_proj_key] = []
-                infer_to_train[down_proj_key].append(
+                if down_proj_key not in self.infer_to_train_mapping:
+                    self.infer_to_train_mapping[down_proj_key] = []
+                self.infer_to_train_mapping[down_proj_key].append(
                     f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                 )
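Each fused inference tensor fans out to a list of per-expert training tensors, so the mapping is one-to-many. For a hypothetical layer 3 with just two experts, the entries produced by the loop above would look like this:

```python
# Hypothetical output for layer_idx=3, experts 0..1 (not taken from a real run):
infer_to_train_mapping = {
    "ernie.layers.3.mlp.fused_moe.up_gate_proj_weight": [
        "ernie.layers.3.mlp.experts.0.up_gate_proj.weight",
        "ernie.layers.3.mlp.experts.1.up_gate_proj.weight",
    ],
    "ernie.layers.3.mlp.fused_moe.down_proj_weight": [
        "ernie.layers.3.mlp.experts.0.down_proj.weight",
        "ernie.layers.3.mlp.experts.1.down_proj.weight",
    ],
}
```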
@@ -136,13 +166,16 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
                          self.fd_config.model_config.num_hidden_layers):
             _add_layer_mappings(layer_idx)

-        return infer_to_train
+        self._complete_missing_mappings()
+
+        return self.infer_to_train_mapping


-class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGeneration):
+class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGeneration, BaseRLModel):
     """
     Ernie4_5_VLMoeForConditionalGenerationRL
     """
+    _get_tensor_parallel_mappings = Ernie4_5_VLPretrainedModel._get_tensor_parallel_mappings

     def __init__(self, fd_config: FDConfig):
         """
@@ -156,64 +189,47 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]

         # Initialize mapping dictionary
-        infer_to_train = {}
-
-        base_name = "ernie"
-        # Static mappings (non-layer specific)
-        static_mappings = {
-            f"{base_name}.embed_tokens.embeddings.weight":
-            f"{base_name}.embed_tokens.weight",
-            "lm_head.linear.weight": "lm_head.weight"
-        }
-        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
-            # Support tie_word_embeddings
-            logger.debug("enable tie_word_embeddings")
-            static_mappings.pop("lm_head.linear.weight")
-        infer_to_train.update(static_mappings)
-
-        base_name = base_name + ".layers"
-
+        self._update_base_mappings("ernie")
+
+        base_name = "ernie.layers"
         # Helper function to add layer mappings
-        def _add_layer_mappings(layer_idx: int, moe_tag: str):
+        def _add_expert_mappings(layer_idx: int, moe_tag: str, expert_start: int):
             # MoE specific mappings
-            infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = f"{base_name}.{layer_idx}.mlp.gate.weight" if moe_tag == "text" else f"{base_name}.{layer_idx}.mlp.gate.weight_1"
+            gate_suffix = "" if moe_tag == "text" else "_1"
+            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = \
+                f"{base_name}.{layer_idx}.mlp.gate.weight{gate_suffix}"

             if self.fd_config.model_config.moe_use_aux_free:
-                infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"] = \
+                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"] = \
                     f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"

-            # MoE experts mappings
-            assert isinstance(self.fd_config.model_config.moe_num_experts, list)
-            if moe_tag == "text":
-                expert_idx_start = 0
-                expert_idx_end = self.fd_config.model_config.moe_num_experts[0]
-            else:
-                expert_idx_start = self.fd_config.model_config.moe_num_experts[0]
-                expert_idx_end = self.fd_config.model_config.moe_num_experts[1]
-
-            for expert_idx in range(expert_idx_start, expert_idx_end):
+            # Initialize defaultdict for expert weights
+            from collections import defaultdict
+            from itertools import chain
+
+            def _generate_ranges(start, end, step=16, take=8):
+                """Yield [start, start+take), [start+step, start+step+take), ... up to end."""
+                return chain(
+                    *(range(i, min(i + take, end))  # guard against running past end
+                      for i in range(start, end, step)))
+
+            expert_mappings = defaultdict(list)
+            for expert_idx in _generate_ranges(expert_start, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank):
                 for ph in place_holders:
-                    # up_gate_proj (up_gate_proj)
-                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"
-                    if up_gate_proj_key not in infer_to_train:
-                        infer_to_train[up_gate_proj_key] = []
-                    infer_to_train[up_gate_proj_key].append(
+                    expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append(
                         f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                     )
-                    # down_proj (down_proj)
-                    down_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"
-                    if down_proj_key not in infer_to_train:
-                        infer_to_train[down_proj_key] = []
-                    infer_to_train[down_proj_key].append(
+                    expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"].append(
                         f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                     )
+            self.infer_to_train_mapping.update(expert_mappings)

         moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
         if isinstance(moe_layer_start_index, int):
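The selection logic hinges on `_generate_ranges` (its docstring is translated from Chinese in the hunk above). Outside the model it is easy to check what the helper yields; here with deliberately small numbers rather than real expert counts:

```python
from itertools import chain

def _generate_ranges(start, end, step=16, take=8):
    """Yield [start, start+take), [start+step, start+step+take), ... up to end."""
    return chain(*(range(i, min(i + take, end))  # guard against running past end
                   for i in range(start, end, step)))

# take=2 out of every step=4 indices in [0, 12):
print(list(_generate_ranges(0, 12, step=4, take=2)))   # [0, 1, 4, 5, 8, 9]
# offset by take, the complementary blocks:
print(list(_generate_ranges(2, 12, step=4, take=2)))   # [2, 3, 6, 7, 10, 11]
```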
@@ -233,19 +249,28 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
         else:
             text_moe_layer_end_index = moe_layer_end_index[0]
             image_moe_layer_end_index = moe_layer_end_index[1]

+        assert isinstance(self.fd_config.model_config.moe_num_experts, list)
+        total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
+        if not trainer_degree:
+            trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
+        expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
         # Process MoE layers
         for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
-            _add_layer_mappings(layer_idx, "text")
+            _add_expert_mappings(layer_idx, "text", expert_start=0)
         for layer_idx in range(image_moe_layer_start_index, image_moe_layer_end_index):
-            _add_layer_mappings(layer_idx, "image")
+            _add_expert_mappings(layer_idx, "image", expert_start=expert_num_per_rank)

-        return infer_to_train
+        self._complete_missing_mappings()
+
+        return self.infer_to_train_mapping


-class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
+class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel):
     """
     Qwen2ForCausalLMRL
     """
+    _get_tensor_parallel_mappings = Qwen2PretrainedModel._get_tensor_parallel_mappings

     def __init__(self, fd_config: FDConfig):
         """
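The `expert_start` arguments only make sense against the training-side layout this commit assumes: each of `trainer_degree` ranks holds a block of text experts followed by a block of image experts. That reading is inferred from the code, and the sizes below are hypothetical:

```python
# Hypothetical sizes, not from a shipped config:
moe_num_experts = [64, 64]          # [text, image]
trainer_degree = 8
total_moe_num = sum(moe_num_experts)                         # 128
expert_num_per_rank = moe_num_experts[0] // trainer_degree   # 8

# Per training rank the flattened order is: 8 text experts, then 8 image experts,
# so text experts occupy [0, 8), [16, 24), ... and image experts [8, 16), [24, 32), ...
step = expert_num_per_rank * 2
print([(i, i + expert_num_per_rank) for i in range(0, total_moe_num, step)][:3])
# [(0, 8), (16, 24), (32, 40)]
```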
@@ -259,43 +284,35 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
         """name"""
         return "Qwen2ForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]

         # Initialize mapping dictionary
-        infer_to_train = {}
-
-        base_name = "qwen2"
-        # Static mappings (non-layer specific)
-        static_mappings = {
-            f"{base_name}.embed_tokens.embeddings.weight":
-            f"{base_name}.embed_tokens.weight",
-            "lm_head.linear.weight": "lm_head.weight"
-        }
-        infer_to_train.update(static_mappings)
-
-        base_name = base_name + ".layers"
-
+        self._update_base_mappings("qwen2")
+        base_name = "qwen2.layers"

         # Helper function to add layer mappings
         def _add_layer_mappings(layer_idx):
             # FFN mappings
             for ph in place_holders:
-                infer_to_train[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
+                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
                     f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"

         for layer_idx in range(
                 self.fd_config.model_config.num_hidden_layers):
             _add_layer_mappings(layer_idx)

-        return infer_to_train
+        self._complete_missing_mappings()
+
+        return self.infer_to_train_mapping


-class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
+class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
     """
     Qwen3MoeForCausalLMRL
     """
+    _get_tensor_parallel_mappings = Qwen3MoePretrainedModel._get_tensor_parallel_mappings

     def __init__(self, fd_config: FDConfig):
         """
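For Qwen2 the only renamed tensor is the fused FFN projection: inference uses `up_gate_proj` while training stores `gate_up_fused_proj`. A self-contained rerun of the loop for a hypothetical two-layer model:

```python
base_name = "qwen2.layers"
place_holders = ["weight"]
infer_to_train_mapping = {}

for layer_idx in range(2):  # pretend num_hidden_layers == 2
    for ph in place_holders:
        infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
            f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"

print(infer_to_train_mapping["qwen2.layers.0.mlp.up_gate_proj.weight"])
# qwen2.layers.0.mlp.gate_up_fused_proj.weight
```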
@@ -309,33 +326,25 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
         """name"""
         return "Qwen3MoeForCausalLMRL"

-    def get_name_mappings_to_training(self) -> Dict[str, str]:
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]

         # Initialize mapping dictionary
-        infer_to_train = {}
-
-        base_name = "model"
-        # Static mappings (non-layer specific)
-        static_mappings = {
-            f"{base_name}.embed_tokens.embeddings.weight":
-            f"{base_name}.embed_tokens.weight",
-            "lm_head.linear.weight": "lm_head.weight"
-        }
-        infer_to_train.update(static_mappings)
-
-        base_name = base_name + ".layers"
-
+        self._update_base_mappings("model")
+        self.infer_to_train_mapping = {}
+
+        base_name = "model.layers"

         # Helper function to add layer mappings
         def _add_layer_mappings(layer_idx: int):
             # MoE specific mappings
-            infer_to_train[f"{base_name}.{layer_idx}.mlp.gate_weight"] = \
+            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate_weight"] = \
                 f"{base_name}.{layer_idx}.mlp.gate.weight"

             if self.fd_config.moe_config.moe_use_aux_free:
-                infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
+                self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
                     f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"

             # MoE experts mappings
@@ -343,17 +352,17 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
             for ph in place_holders:
                 # up_gate_proj (up_gate_proj)
                 up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.up_gate_proj_weight"
-                if up_gate_proj_key not in infer_to_train:
-                    infer_to_train[up_gate_proj_key] = []
-                infer_to_train[up_gate_proj_key].append(
+                if up_gate_proj_key not in self.infer_to_train_mapping:
+                    self.infer_to_train_mapping[up_gate_proj_key] = []
+                self.infer_to_train_mapping[up_gate_proj_key].append(
                     f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                 )

                 # down_proj (down_proj)
                 down_proj_key = f"{base_name}.{layer_idx}.mlp.down_proj_weight"
-                if down_proj_key not in infer_to_train:
-                    infer_to_train[down_proj_key] = []
-                infer_to_train[down_proj_key].append(
+                if down_proj_key not in self.infer_to_train_mapping:
+                    self.infer_to_train_mapping[down_proj_key] = []
+                self.infer_to_train_mapping[down_proj_key].append(
                     f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                 )
@@ -361,13 +370,16 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
         for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
             _add_layer_mappings(layer_idx)

-        return infer_to_train
+        self._complete_missing_mappings()
+
+        return self.infer_to_train_mapping


-class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
+class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel):
     """
     Qwen3ForCausalLMRL
     """
+    _get_tensor_parallel_mappings = Qwen3PretrainedModel._get_tensor_parallel_mappings

     def __init__(self, fd_config: FDConfig):
         """
@@ -380,3 +392,6 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
     def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"
+
+    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
+        pass
@@ -25,7 +25,8 @@ import paddle.distributed.fleet as fleet

 from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
                                GraphOptimizationConfig, LoadConfig,
-                               ModelConfig, ParallelConfig, SpeculativeConfig)
+                               ModelConfig, ParallelConfig, SpeculativeConfig,
+                               ErnieArchitectures)
 from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
@@ -641,9 +642,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     quant_config_name = args.quantization
     quantization_config["quantization"] = quant_config_name
     # Special handling for Ernie models
-    is_ernie = "Ernie4_5_ForCausalLM" in model_config.architectures or \
-        "Ernie4_5_MoeForCausalLM" in model_config.architectures or \
-        "Ernie4_5_VLMoeForConditionalGeneration" in model_config.architectures
+    is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
     if quant_config_name == "wint4" and is_ernie:
         quantization_config["dense_quant_type"] = "wint8"
         quantization_config["moe_quant_type"] = "wint4"
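The last hunk collapses the three-way architecture check but keeps the wint4 special case: ERNIE models quantize dense layers to wint8 while MoE layers stay wint4. A sketch of the resulting dict (inputs are hypothetical, and the `is_ernie` line is a stand-in for the helper):

```python
# Hypothetical inputs illustrating the wint4 special case for ERNIE models:
quant_config_name = "wint4"
architectures = ["Ernie4_5_MoeForCausalLM"]

quantization_config = {"quantization": quant_config_name}
is_ernie = any(a.startswith("Ernie4_5") for a in architectures)  # stand-in for the helper
if quant_config_name == "wint4" and is_ernie:
    quantization_config["dense_quant_type"] = "wint8"
    quantization_config["moe_quant_type"] = "wint4"

print(quantization_config)
# {'quantization': 'wint4', 'dense_quant_type': 'wint8', 'moe_quant_type': 'wint4'}
```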