mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
【Sync develop】support vl model name_mapping and ori_vocab_size (#2915)
* support vl ori_vacab_size * support trainer_degree in name_mapping * fix
This commit is contained in:
@@ -23,15 +23,14 @@ from paddleformers.utils.log import logger
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.model_loader import ModelRegistry
|
||||
from fastdeploy.model_executor.models.ernie4_5_moe import \
|
||||
Ernie4_5_MoeForCausalLM
|
||||
Ernie4_5_MoeForCausalLM, Ernie4_5_PretrainedModel
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import \
|
||||
Ernie4_5_VLMoeForConditionalGeneration
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
|
||||
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
|
||||
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
|
||||
Ernie4_5_VLMoeForConditionalGeneration, Ernie4_5_VLPretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2PretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3PretrainedModel
|
||||
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM, Qwen3MoePretrainedModel
|
||||
from fastdeploy.rl.rollout_config import RolloutModelConfig
|
||||
|
||||
|
||||
class RolloutModel(nn.Layer):
|
||||
"""Main model class for rollout operations, supports multimodal components for train."""
|
||||
|
||||
@@ -51,9 +50,13 @@ class RolloutModel(nn.Layer):
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def get_name_mappings_to_training(self) -> Dict[str, str]:
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
"""Get parameter name mappings between rollout and training models."""
|
||||
return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})()
|
||||
return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
|
||||
|
||||
def get_quantization_infer_keys(self) -> Dict[str, str]:
|
||||
"""Get parameter name mappings between rollout and training models."""
|
||||
return getattr(self.rollout_model, "get_quantization_infer_keys", lambda: {})()
|
||||
|
||||
@paddle.no_grad()
|
||||
def state_dict(self):
|
||||
@@ -61,10 +64,51 @@ class RolloutModel(nn.Layer):
|
||||
return self.rollout_model.state_dict()
|
||||
|
||||
|
||||
class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
class BaseRLModel(nn.Layer):
|
||||
"""Base class for RL models with common functionality"""
|
||||
def __init__(self,):
|
||||
super(BaseRLModel, self).__init__()
|
||||
self.infer_to_train_mapping = {}
|
||||
self.fd_config = None
|
||||
|
||||
@classmethod
|
||||
def name(cls) -> str:
|
||||
return cls.__name__
|
||||
|
||||
def _update_base_mappings(self, base_name: str) -> None:
|
||||
"""Common static mappings"""
|
||||
static_mappings = {
|
||||
f"{base_name}.embed_tokens.embeddings.weight": f"{base_name}.embed_tokens.weight",
|
||||
"lm_head.linear.weight": "lm_head.weight"
|
||||
}
|
||||
self.infer_to_train_mapping.update(static_mappings)
|
||||
|
||||
def _complete_missing_mappings(self) -> None:
|
||||
"""
|
||||
Complete the mapping dictionary with keys that have identical names in inference and training.
|
||||
"""
|
||||
for key in self.state_dict().keys():
|
||||
if key not in self.infer_to_train_mapping and "_scale" not in key:
|
||||
# Skip weight scale parameters in mapping. Train and infer have same key.
|
||||
self.infer_to_train_mapping[key] = key
|
||||
|
||||
def get_quantization_infer_keys(self) -> list[str]:
|
||||
"""Get quantization infer keys"""
|
||||
quant_weight_key = []
|
||||
if self.fd_config.quant_config.name() == "wint8":
|
||||
""" RL only support weight_only_int8 now"""
|
||||
for key in self.state_dict().keys():
|
||||
if "scale" in key:
|
||||
quant_weight_key.append(key.replace(".weight_scale", ".weight"))
|
||||
else:
|
||||
raise ValueError("Only 'wint8' quantization is supported in RL roullout.")
|
||||
return quant_weight_key
|
||||
|
||||
class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
|
||||
"""
|
||||
Ernie4_5_MoeForCausalLMRL
|
||||
"""
|
||||
_get_tensor_parallel_mappings = Ernie4_5_PretrainedModel._get_tensor_parallel_mappings
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
@@ -78,37 +122,23 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
"""name"""
|
||||
return "Ernie4_5_MoeForCausalLMRL"
|
||||
|
||||
def get_name_mappings_to_training(self) -> Dict[str, str]:
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
|
||||
base_name = "ernie"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{base_name}.embed_tokens.embeddings.weight":
|
||||
f"{base_name}.embed_tokens.weight",
|
||||
"lm_head.linear.weight": "lm_head.weight"
|
||||
}
|
||||
if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
|
||||
# Support tie_word_embeddings
|
||||
logger.debug("enable tie_word_embeddings")
|
||||
static_mappings.pop("lm_head.linear.weight")
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
base_name = base_name + ".layers"
|
||||
self._update_base_mappings("ernie")
|
||||
|
||||
base_name = "ernie.layers"
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx: int):
|
||||
# MoE specific mappings
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.gate.weight"
|
||||
|
||||
if self.fd_config.model_config.moe_use_aux_free:
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
|
||||
# MoE experts mappings
|
||||
@@ -116,17 +146,17 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
for ph in place_holders:
|
||||
# up_gate_proj (up_gate_proj)
|
||||
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.up_gate_proj_weight"
|
||||
if up_gate_proj_key not in infer_to_train:
|
||||
infer_to_train[up_gate_proj_key] = []
|
||||
infer_to_train[up_gate_proj_key].append(
|
||||
if up_gate_proj_key not in self.infer_to_train_mapping:
|
||||
self.infer_to_train_mapping[up_gate_proj_key] = []
|
||||
self.infer_to_train_mapping[up_gate_proj_key].append(
|
||||
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
)
|
||||
|
||||
# down_proj (down_proj)
|
||||
down_proj_key = f"{base_name}.{layer_idx}.mlp.fused_moe.down_proj_weight"
|
||||
if down_proj_key not in infer_to_train:
|
||||
infer_to_train[down_proj_key] = []
|
||||
infer_to_train[down_proj_key].append(
|
||||
if down_proj_key not in self.infer_to_train_mapping:
|
||||
self.infer_to_train_mapping[down_proj_key] = []
|
||||
self.infer_to_train_mapping[down_proj_key].append(
|
||||
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
)
|
||||
|
||||
@@ -136,13 +166,16 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
self.fd_config.model_config.num_hidden_layers):
|
||||
_add_layer_mappings(layer_idx)
|
||||
|
||||
return infer_to_train
|
||||
self._complete_missing_mappings()
|
||||
|
||||
return self.infer_to_train_mapping
|
||||
|
||||
|
||||
class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGeneration):
|
||||
class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGeneration, BaseRLModel):
|
||||
"""
|
||||
Ernie4_5_VLMoeForConditionalGenerationRL
|
||||
"""
|
||||
_get_tensor_parallel_mappings = Ernie4_5_VLPretrainedModel._get_tensor_parallel_mappings
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
@@ -156,64 +189,47 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
|
||||
"""name"""
|
||||
return "Ernie4_5_VLMoeForConditionalGenerationRL"
|
||||
|
||||
def get_name_mappings_to_training(self) -> Dict[str, str]:
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
self._update_base_mappings("ernie")
|
||||
|
||||
base_name = "ernie"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{base_name}.embed_tokens.embeddings.weight":
|
||||
f"{base_name}.embed_tokens.weight",
|
||||
"lm_head.linear.weight": "lm_head.weight"
|
||||
}
|
||||
if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
|
||||
# Support tie_word_embeddings
|
||||
logger.debug("enable tie_word_embeddings")
|
||||
static_mappings.pop("lm_head.linear.weight")
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
base_name = base_name + ".layers"
|
||||
base_name = "ernie.layers"
|
||||
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx: int, moe_tag: str):
|
||||
def _add_expert_mappings(layer_idx: int, moe_tag: str, expert_start: int):
|
||||
# MoE specific mappings
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = f"{base_name}.{layer_idx}.mlp.gate.weight" if moe_tag == "text" else f"{base_name}.{layer_idx}.mlp.gate.weight_1"
|
||||
gate_suffix = "" if moe_tag == "text" else "_1"
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_weight"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.gate.weight{gate_suffix}"
|
||||
|
||||
if self.fd_config.model_config.moe_use_aux_free:
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"] = \
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.gate_correction_bias"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
|
||||
# MoE experts mappings
|
||||
assert isinstance(self.fd_config.model_config.moe_num_experts, list)
|
||||
if moe_tag == "text":
|
||||
expert_idx_start = 0
|
||||
expert_idx_end = self.fd_config.model_config.moe_num_experts[0]
|
||||
else:
|
||||
expert_idx_start = self.fd_config.model_config.moe_num_experts[0]
|
||||
expert_idx_end = self.fd_config.model_config.moe_num_experts[1]
|
||||
|
||||
for expert_idx in range(expert_idx_start, expert_idx_end):
|
||||
# Initialize defaultdict for expert weights
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
|
||||
def _generate_ranges(start, end, step=16, take=8):
|
||||
"""生成 [start, start+take), [start+step, start+step+take), ... 直到 end"""
|
||||
return chain(
|
||||
*(range(i, min(i + take, end)) # 防止越界
|
||||
for i in range(start, end, step)))
|
||||
|
||||
expert_mappings = defaultdict(list)
|
||||
for expert_idx in _generate_ranges(expert_start, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank):
|
||||
for ph in place_holders:
|
||||
# up_gate_proj (up_gate_proj)
|
||||
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"
|
||||
if up_gate_proj_key not in infer_to_train:
|
||||
infer_to_train[up_gate_proj_key] = []
|
||||
infer_to_train[up_gate_proj_key].append(
|
||||
expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append(
|
||||
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
)
|
||||
|
||||
# down_proj (down_proj)
|
||||
down_proj_key = f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"
|
||||
if down_proj_key not in infer_to_train:
|
||||
infer_to_train[down_proj_key] = []
|
||||
infer_to_train[down_proj_key].append(
|
||||
expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.down_proj_weight"].append(
|
||||
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
)
|
||||
self.infer_to_train_mapping.update(expert_mappings)
|
||||
|
||||
moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
|
||||
if isinstance(moe_layer_start_index, int):
|
||||
@@ -233,19 +249,28 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
|
||||
else:
|
||||
text_moe_layer_end_index = moe_layer_end_index[0]
|
||||
image_moe_layer_end_index = moe_layer_end_index[1]
|
||||
|
||||
assert isinstance(self.fd_config.model_config.moe_num_experts, list)
|
||||
total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
|
||||
if not trainer_degree:
|
||||
trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
|
||||
expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
|
||||
# Process MoE layers
|
||||
for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
|
||||
_add_layer_mappings(layer_idx, "text")
|
||||
_add_expert_mappings(layer_idx, "text", expert_start=0)
|
||||
for layer_idx in range(image_moe_layer_start_index, image_moe_layer_end_index):
|
||||
_add_layer_mappings(layer_idx, "image")
|
||||
_add_expert_mappings(layer_idx, "image", expert_start=expert_num_per_rank)
|
||||
|
||||
return infer_to_train
|
||||
self._complete_missing_mappings()
|
||||
|
||||
return self.infer_to_train_mapping
|
||||
|
||||
|
||||
class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
|
||||
class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel):
|
||||
"""
|
||||
Qwen2ForCausalLMRL
|
||||
"""
|
||||
_get_tensor_parallel_mappings = Qwen2PretrainedModel._get_tensor_parallel_mappings
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
@@ -259,43 +284,35 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
|
||||
"""name"""
|
||||
return "Qwen2ForCausalLMRL"
|
||||
|
||||
def get_name_mappings_to_training(self) -> Dict[str, str]:
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
|
||||
base_name = "qwen2"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{base_name}.embed_tokens.embeddings.weight":
|
||||
f"{base_name}.embed_tokens.weight",
|
||||
"lm_head.linear.weight": "lm_head.weight"
|
||||
}
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
base_name = base_name + ".layers"
|
||||
|
||||
self._update_base_mappings("qwen2")
|
||||
base_name = "qwen2.layers"
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx):
|
||||
# FFN mappings
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
|
||||
|
||||
for layer_idx in range(
|
||||
self.fd_config.model_config.num_hidden_layers):
|
||||
_add_layer_mappings(layer_idx)
|
||||
|
||||
return infer_to_train
|
||||
self._complete_missing_mappings()
|
||||
|
||||
return self.infer_to_train_mapping
|
||||
|
||||
|
||||
class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
|
||||
class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
|
||||
"""
|
||||
Qwen3MoeForCausalLMRL
|
||||
"""
|
||||
_get_tensor_parallel_mappings = Qwen3MoePretrainedModel._get_tensor_parallel_mappings
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
@@ -309,33 +326,25 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
|
||||
"""name"""
|
||||
return "Qwen3MoeForCausalLMRL"
|
||||
|
||||
def get_name_mappings_to_training(self) -> Dict[str, str]:
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
self._update_base_mappings("model")
|
||||
self.infer_to_train_mapping = {}
|
||||
|
||||
base_name = "model"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{base_name}.embed_tokens.embeddings.weight":
|
||||
f"{base_name}.embed_tokens.weight",
|
||||
"lm_head.linear.weight": "lm_head.weight"
|
||||
}
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
base_name = base_name + ".layers"
|
||||
base_name = "model.layers"
|
||||
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx: int):
|
||||
# MoE specific mappings
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.gate_weight"] = \
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate_weight"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.gate.weight"
|
||||
|
||||
if self.fd_config.moe_config.moe_use_aux_free:
|
||||
infer_to_train[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
f"{base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
|
||||
# MoE experts mappings
|
||||
@@ -343,17 +352,17 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
|
||||
for ph in place_holders:
|
||||
# up_gate_proj (up_gate_proj)
|
||||
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.up_gate_proj_weight"
|
||||
if up_gate_proj_key not in infer_to_train:
|
||||
infer_to_train[up_gate_proj_key] = []
|
||||
infer_to_train[up_gate_proj_key].append(
|
||||
if up_gate_proj_key not in self.infer_to_train_mapping:
|
||||
self.infer_to_train_mapping[up_gate_proj_key] = []
|
||||
self.infer_to_train_mapping[up_gate_proj_key].append(
|
||||
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
)
|
||||
|
||||
# down_proj (down_proj)
|
||||
down_proj_key = f"{base_name}.{layer_idx}.mlp.down_proj_weight"
|
||||
if down_proj_key not in infer_to_train:
|
||||
infer_to_train[down_proj_key] = []
|
||||
infer_to_train[down_proj_key].append(
|
||||
if down_proj_key not in self.infer_to_train_mapping:
|
||||
self.infer_to_train_mapping[down_proj_key] = []
|
||||
self.infer_to_train_mapping[down_proj_key].append(
|
||||
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
)
|
||||
|
||||
@@ -361,13 +370,16 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
|
||||
for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
|
||||
_add_layer_mappings(layer_idx)
|
||||
|
||||
return infer_to_train
|
||||
self._complete_missing_mappings()
|
||||
|
||||
return self.infer_to_train_mapping
|
||||
|
||||
|
||||
class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
|
||||
class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel):
|
||||
"""
|
||||
Qwen3ForCausalLMRL
|
||||
"""
|
||||
_get_tensor_parallel_mappings = Qwen3PretrainedModel._get_tensor_parallel_mappings
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
@@ -380,3 +392,6 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
|
||||
def name(self) -> str:
|
||||
"""name"""
|
||||
return "Qwen3ForCausalLMRL"
|
||||
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
pass
|
Reference in New Issue
Block a user