Mirror of https://github.com/PaddlePaddle/FastDeploy.git
【Feature】Support some Qwen2 functionality (#2740)
* add RL Qwen model support
* fix
* fix
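In outline: the hard-coded RL_MODEL_CLASSES table is dropped in favor of the ForCausalLM model classes, RollOutModel is renamed RolloutModel, and its constructor now takes a RolloutModelConfig that builds the FDConfig itself. A minimal usage sketch under those assumptions follows; the RolloutModelConfig arguments and the module path are illustrative, not taken from this diff.

# Hypothetical driver for the reworked entry point; argument names
# and module path are assumptions, not code from this PR.
from fastdeploy.rl.rollout_config import RolloutModelConfig
from fastdeploy.rl.rollout_model import RolloutModel  # module path assumed

rollout_config = RolloutModelConfig(model_name_or_path="Qwen/Qwen2-7B-Instruct")
model = RolloutModel(rollout_config)  # __init__ calls rollout_config.initialize() -> FDConfig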
@@ -24,25 +24,18 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.model_loader import ModelRegistry
 from fastdeploy.model_executor.models.ernie4_5_moe import \
     Ernie4_5_MoeForCausalLM
-from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
-from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
-from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
+from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
+from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
+from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
+from fastdeploy.rl.rollout_config import RolloutModelConfig

-RL_MODEL_CLASSES = {
-    "Ernie4_5_MoeForCausalLMRL": Ernie4_5_MoeForCausalLM,
-    "Qwen2ForCausalLMRL": Qwen2PretrainedModel,
-    "Qwen3ForCausalLMRL": Qwen3PretrainedModel,
-    "Qwen3MoeForCausalLMRL": Qwen3MoePretrainedModel,
-}


-class RollOutModel(nn.Layer):
+class RolloutModel(nn.Layer):
     """Main model class for rollout operations, supports multimodal components for train."""

-    def __init__(self, fd_config: FDConfig):
+    def __init__(self, rollout_model_config: RolloutModelConfig):
         """Initialize with FastDeploy configuration."""
-        super(RollOutModel, self).__init__()
-        self.fd_config = fd_config
+        super(RolloutModel, self).__init__()
+        self.fd_config = rollout_model_config.initialize()
         self._init_models()

     def _init_models(self):
@@ -90,9 +83,6 @@ class RollOutModel(nn.Layer):
         all_params = {}
         for model in self.rollout_models:
             for name, param in model.state_dict().items():
-                logger.debug(
-                    f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
-                )
                 all_params[name] = param
         return all_params
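The remaining hunks generalize the inference-to-training parameter-name mapping. Conceptually, the infer_to_train dict these methods return is consumed by a renamer along the following lines; this consumer is an illustrative sketch, not code from this PR.

# Sketch of a consumer for an infer->train name mapping (illustrative only).
def rename_for_training(infer_state: dict, infer_to_train: dict) -> dict:
    train_state = {}
    for infer_key, train_key in infer_to_train.items():
        if isinstance(train_key, list):
            # Fused MoE tensors map one inference key to several training
            # keys (one per expert); splitting them is omitted here.
            continue
        if infer_key in infer_state:
            train_state[train_key] = infer_state[infer_key]
    return train_state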
@@ -123,11 +113,13 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
         # Initialize mapping dictionary
         infer_to_train = {}

+        infer_base_name = "model"
+        train_base_name = "ernie"
         # Static mappings (non-layer specific)
         static_mappings = {
-            "model.embeddings.word_embeddings.weight":
-                "ernie.embed_tokens.weight",
-            "model.norm.ln_weight": "ernie.norm.weight",
+            f"{infer_base_name}.embeddings.word_embeddings.weight":
+                f"{train_base_name}.embed_tokens.weight",
+            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
             "lm_head.out_linear.weight": "lm_head.weight"
         }
         if self.fd_config.model_config.get("weight_sharing", False):
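The f-string entries expand to exactly the literals they replace, which can be checked directly:

# Expansion check, derived from the hunk above (plain Python).
infer_base_name, train_base_name = "model", "ernie"
assert f"{infer_base_name}.norm.ln_weight" == "model.norm.ln_weight"
assert f"{train_base_name}.embed_tokens.weight" == "ernie.embed_tokens.weight"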
@@ -135,53 +127,55 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.out_linear.weight")
         infer_to_train.update(static_mappings)
-        infer_base_name = "model.hidden_layers"

+        infer_base_name = infer_base_name + ".hidden_layers"
+        train_base_name = train_base_name + ".layers"

         # Helper function to add layer mappings
         def _add_layer_mappings(layer_idx, is_moe_layer=False):
             # Handle special case for layer 0's input layernorm
             for ph in place_holders:
                 infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
-                train_key = f"ernie.layers.{layer_idx}.input_layernorm.{ph}"
+                train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
                 infer_to_train[infer_key] = train_key

             # Common attention mappings
             for ph in place_holders:
                 infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
-                    f"ernie.layers.{layer_idx}.self_attn.qkv_proj.{ph}"
+                    f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"

                 infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
-                    f"ernie.layers.{layer_idx}.self_attn.o_proj.{ph}"
+                    f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"

             # Post-attention layernorm
             for ph in place_holders:
                 infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
-                    f"ernie.layers.{layer_idx}.post_attention_layernorm.{ph}"
+                    f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"

             if not is_moe_layer:
                 # Dense FFN mappings
                 for ph in place_holders:
                     infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
-                        f"ernie.layers.{layer_idx}.mlp.up_gate_proj.{ph}"
+                        f"{train_base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"

                     infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
-                        f"ernie.layers.{layer_idx}.mlp.down_proj.{ph}"
+                        f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
             else:
                 # MoE specific mappings
                 infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
-                    f"ernie.layers.{layer_idx}.mlp.gate.weight"
+                    f"{train_base_name}.{layer_idx}.mlp.gate.weight"

                 if self.fd_config.moe_config.moe_use_aux_free:
                     infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
-                        f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
+                        f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"

                 # Support shared experts
                 if self.fd_config.model_config.get(
                         "moe_num_shared_experts") > 0:
                     infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
-                        f"ernie.layers.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
+                        f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
                     infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
-                        f"ernie.layers.{layer_idx}.mlp.shared_experts.down_proj.weight"
+                        f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"

                 # MoE experts mappings
                 for expert_idx in range(self.fd_config.moe_config.num_experts):
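With infer_base_name = "model.hidden_layers" and train_base_name = "ernie.layers", the helper emits pairs such as the following for layer 0 (derived by hand from the hunk above):

# "model.hidden_layers.0.input_layernorm.ln_weight"        -> "ernie.layers.0.input_layernorm.weight"
# "model.hidden_layers.0.self_attn.qkv_proj.linear_weight" -> "ernie.layers.0.self_attn.qkv_proj.weight"
# "model.hidden_layers.0.mlp.gate_up_proj.linear_weight"   -> "ernie.layers.0.mlp.up_gate_proj.weight"  (dense layers only)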
@@ -191,7 +185,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
                     if ffn1_key not in infer_to_train:
                         infer_to_train[ffn1_key] = []
                     infer_to_train[ffn1_key].append(
-                        f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
+                        f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                     )

                     # FFN2 (down_proj)
@@ -199,7 +193,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
                     if ffn2_key not in infer_to_train:
                         infer_to_train[ffn2_key] = []
                     infer_to_train[ffn2_key].append(
-                        f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
+                        f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                     )

         # Process non-MoE layers
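For expert weights the mapping value is a list rather than a single string: each fused inference tensor collects one training tensor per expert. Illustratively, for a two-expert MoE layer (the fused key names ffn1_key and ffn2_key are defined in lines elided from this diff):

# infer_to_train[ffn1_key] == [
#     "ernie.layers.<i>.mlp.experts.0.up_gate_proj.weight",
#     "ernie.layers.<i>.mlp.experts.1.up_gate_proj.weight",
# ]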
@@ -213,3 +207,118 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
             _add_layer_mappings(layer_idx, is_moe_layer=True)

         return infer_to_train
+
+
+class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
+    """
+    Qwen2ForCausalLMRL
+    """
+
+    def __init__(self, fd_config: FDConfig):
+        """
+        Args:
+            fd_config (FDConfig): Configurations for the LLM model.
+        """
+        super(Qwen2ForCausalLMRL, self).__init__(fd_config)
+
+    @classmethod
+    def name(self):
+        """name"""
+        return "Qwen2ForCausalLMRL"
+
+    def get_name_mappings_to_training(self):
+        """Generate mapping between inference and training parameter for RL(donot delete!)."""
+        # Prepare placeholders
+        place_holders = ["weight"]
+
+        # Initialize mapping dictionary
+        infer_to_train = {}
+
+        infer_base_name = "model"
+        train_base_name = "qwen2"
+        # Static mappings (non-layer specific)
+        static_mappings = {
+            f"{infer_base_name}.embeddings.word_embeddings.weight":
+                f"{train_base_name}.embed_tokens.weight",
+            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
+            "lm_head.out_linear.weight": "lm_head.weight"
+        }
+        infer_to_train.update(static_mappings)
+
+        infer_base_name = infer_base_name + ".layers"
+        train_base_name = train_base_name + ".layers"
+
+        # Helper function to add layer mappings
+        def _add_layer_mappings(layer_idx):
+            # Handle special case for layer 0's input layernorm and attn o_proj
+            for ph in place_holders:
+                infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
+                train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
+                infer_to_train[infer_key] = train_key
+
+                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
+                    f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
+
+            # qwen qkv proj need bias
+            for ph in ["weight", "bias"]:
+                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
+                    f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
+
+            # Post-attention layernorm
+            for ph in place_holders:
+                infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
+                    f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
+
+            # FFN mappings
+            for ph in place_holders:
+                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
+                    f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
+
+                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
+                    f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
+
+        for layer_idx in range(
+                self.fd_config.model_config.num_layers):
+            _add_layer_mappings(layer_idx)
+
+        return infer_to_train
+
+
+class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
+    """
+    Qwen3MoeForCausalLMRL
+    """
+
+    def __init__(self, fd_config: FDConfig):
+        """
+        Args:
+            fd_config (FDConfig): Configurations for the LLM model.
+        """
+        super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)
+
+    @classmethod
+    def name(self):
+        """name"""
+        return "Qwen3MoeForCausalLMRL"
+
+    def get_name_mappings_to_training(self):
+        """Generate mapping between inference and training parameter for RL(donot delete!)."""
+        pass
+
+
+class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
+    """
+    Qwen3ForCausalLMRL
+    """
+
+    def __init__(self, fd_config: FDConfig):
+        """
+        Args:
+            fd_config (FDConfig): Configurations for the LLM model.
+        """
+        super(Qwen3ForCausalLMRL, self).__init__(fd_config)
+
+    @classmethod
+    def name(self):
+        """name"""
+        return "Qwen3ForCausalLMRL"
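Two review notes on the added code: Qwen3MoeForCausalLMRL.get_name_mappings_to_training is left as a stub (pass) in this PR, and name() is declared @classmethod but still written with a self parameter. For Qwen2, the mapping method above produces pairs such as the following (derived by hand, layer 0):

# "model.embeddings.word_embeddings.weight"         -> "qwen2.embed_tokens.weight"
# "model.layers.0.self_attn.qkv_proj.linear_weight" -> "qwen2.layers.0.self_attn.qkv_proj.weight"
# "model.layers.0.self_attn.qkv_proj.linear_bias"   -> "qwen2.layers.0.self_attn.qkv_proj.bias"
# "model.layers.0.mlp.gate_up_proj.linear_weight"   -> "qwen2.layers.0.mlp.gate_up_fused_proj.weight"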