[RL] Resolve shape mismatch problems in RL-related modules (#5032)

* RL fix

* update
Author: bukejiyu (committed by GitHub)
Date: 2025-11-19 11:12:48 +08:00
parent 4694ed2a43, commit a82f25ea7b
12 changed files with 61 additions and 87 deletions

View File

@@ -959,7 +959,7 @@ class KVBatchLinear(nn.Layer):
# Split num_attention_heads when using TP inference.
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
self.local_rank = fd_config.parallel_config.tensor_parallel_rank
self.fd_config = fd_config
self.kv_b_proj = kv_b_proj
self.weight_dtype = self._helper.get_default_dtype()
@@ -968,7 +968,8 @@ class KVBatchLinear(nn.Layer):
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"
def process_weights_after_loading(self):
if self.fd_config.load_config.dynamic_load_weight:
return
w = self.kv_b_proj.weight.reshape(
[
self.kv_lora_rank,

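The two additions in this hunk keep a reference to fd_config and skip the post-loading reshape when weights are loaded dynamically (the RL weight-sync path), where kv_b_proj.weight may not be materialized yet. A minimal, self-contained sketch of that guard; _LoadConfig, _FDConfig and _KVBatchLinearSketch are stand-ins, not FastDeploy's real classes, and numpy stands in for paddle tensors:

from dataclasses import dataclass

import numpy as np


@dataclass
class _LoadConfig:                 # stand-in for fd_config.load_config
    dynamic_load_weight: bool = False


@dataclass
class _FDConfig:                   # stand-in for fd_config
    load_config: _LoadConfig


class _KVBatchLinearSketch:
    def __init__(self, fd_config, kv_b_weight, kv_lora_rank, num_heads, head_dim):
        self.fd_config = fd_config
        self.kv_b_weight = kv_b_weight     # [kv_lora_rank, num_heads * head_dim]
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.w = None

    def process_weights_after_loading(self):
        # Under dynamic weight loading the tensor is filled in later, so the
        # reshape is deferred rather than run on a placeholder.
        if self.fd_config.load_config.dynamic_load_weight:
            return
        self.w = self.kv_b_weight.reshape(
            [self.kv_lora_rank, self.num_heads, self.head_dim]
        )


layer = _KVBatchLinearSketch(
    _FDConfig(_LoadConfig(dynamic_load_weight=True)),
    np.zeros((16, 4 * 8), dtype=np.float32),
    kv_lora_rank=16, num_heads=4, head_dim=8,
)
layer.process_weights_after_loading()
assert layer.w is None   # skipped: the RL loader will provide the weight later
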
View File

@@ -26,12 +26,7 @@ from fastdeploy.model_executor.layers.utils import (
DEFAULT_VOCAB_PADDING_SIZE,
pad_vocab_size,
)
from fastdeploy.model_executor.utils import (
default_weight_loader,
free_tensor,
set_weight_attrs,
temporary_dtype,
)
from fastdeploy.model_executor.utils import set_weight_attrs, temporary_dtype
from .utils import get_tensor
@@ -80,7 +75,6 @@ class ParallelLMHead(nn.Layer):
if num_embeddings % self.nranks != 0:
num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
self.num_embeddings = num_embeddings
self.model_format = fd_config.model_config.model_format
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -90,39 +84,21 @@ class ParallelLMHead(nn.Layer):
self.need_gather = True
with temporary_dtype(self.dtype):
if self.fd_config.load_config.load_choices == "default_v1" and (
self.model_format == "torch" or self.tie_word_embeddings
):
self.linear = RowParallelLinear(
num_embeddings,
embedding_dim,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
},
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
elif self.column_cut:
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=self.need_gather,
gather_output=need_gather,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"weight_need_transpose": self.fd_config.model_config.model_format == "torch",
},
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
@@ -139,34 +115,11 @@ class ParallelLMHead(nn.Layer):
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"weight_need_transpose": self.fd_config.model_config.model_format == "torch",
},
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
def process_weights_after_loading(self):
if not (
self.fd_config.load_config.load_choices == "default_v1"
and (self.model_format == "torch" or self.tie_word_embeddings)
):
return
if not self.linear.weight._is_initialized():
self.linear.weight.initialize()
weight_transpose = self.linear.weight.transpose([1, 0])
with temporary_dtype(self.dtype):
linear = fleet.meta_parallel.ColumnParallelLinear(
self.embedding_dim,
self.num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=self.need_gather,
fuse_matmul_bias=False,
)
linear.weight.set_value(weight_transpose)
free_tensor(self.linear.weight)
self.linear = linear
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.

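This section drops the special-cased RowParallelLinear path for torch-format or tied-embedding checkpoints, along with the process_weights_after_loading rebuild that used to transpose the weight and swap in a ColumnParallelLinear. The lm_head keeps only its regular column_cut / non-column_cut construction, and the adaptation happens at load time instead: the weight_need_transpose attribute covers torch-format checkpoints, and tied embeddings are transposed where they are assigned (see the model classes further down). The underlying shape relationship, sketched with plain paddle layers standing in for the tensor-parallel ones:

import paddle

hidden_size, vocab_size = 8, 32                                       # toy sizes

embed = paddle.nn.Embedding(vocab_size, hidden_size)                  # weight: [vocab, hidden]
lm_head = paddle.nn.Linear(hidden_size, vocab_size, bias_attr=False)  # weight: [hidden, vocab]

# Paddle linear weights are laid out [in_features, out_features], so tying the
# embedding table to the lm_head only lines up after a transpose.
lm_head.weight.set_value(embed.weight.transpose([1, 0]))
assert lm_head.weight.shape == [hidden_size, vocab_size]
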
View File

@@ -1422,13 +1422,6 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
if (
not weight_fully_copied(getattr(layer, up_gate_proj_weight_name))
or not weight_fully_copied(getattr(layer, down_proj_weight_name))
or not weight_fully_copied(getattr(layer, up_gate_proj_scale_name))
or not weight_fully_copied(getattr(layer, down_proj_scale_name))
):
return
process_weight_transpose(layer, up_gate_proj_weight_name)
process_weight_transpose(layer, down_proj_weight_name)
process_weight_transpose(layer, up_gate_proj_scale_name)

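With the partial-copy guards gone, the block-wise FP8 MoE weights and scales are transposed unconditionally during post-loading. For these per-expert tensors, process_weight_transpose (see the utils change at the end of this diff) keeps the expert axis in place and swaps only the two trailing matmul axes; a small shape check with hypothetical sizes:

import numpy as np

num_experts, dim_a, dim_b = 8, 1024, 2048     # hypothetical sizes

# 3D per-expert weight (the block-wise scales follow the same rank-3 layout).
up_gate_proj_weight = np.zeros((num_experts, dim_a, dim_b), dtype=np.float32)

transposed = up_gate_proj_weight.transpose([0, 2, 1])   # expert axis untouched
assert transposed.shape == (num_experts, dim_b, dim_a)
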
View File

@@ -63,7 +63,7 @@ def get_moe_method():
)
return MetaxCutlassUnquantizedFusedMoEMethod(None)
raise NotImplementedError
return None
def get_moe_scores(
@@ -189,7 +189,9 @@ class FusedMoE(nn.Layer):
self.quant_method = moe_quant_config.get_quant_method(self)
self.moe_quant_type = moe_quant_config.name()
else:
# unquantized quant_method
self.quant_method = get_moe_method()
assert self.quant_method is not None, "self.quant_method should not be None"
self.redundant_table_manger = redundant_table_manger
if self.ep_size > 1:
self.quant_method.init_ep(self)

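get_moe_method() now returns None instead of raising when the current backend has no unquantized MoE implementation, so every caller must cope with None: FusedMoE asserts it actually got a method, while the weight-processing utilities later in this diff fall back to object so the type comparison can never match. Two hypothetical helpers (not part of the codebase) mirroring those call-site patterns:

def require_moe_method(get_moe_method):
    # Pattern used where an unquantized MoE method is mandatory (FusedMoE.__init__).
    method = get_moe_method()
    assert method is not None, "self.quant_method should not be None"
    return method


def unquantized_moe_cls(get_moe_method):
    # Pattern used where only a type comparison is needed; `object` never matches.
    method = get_moe_method()
    return object if method is None else type(method)


assert unquantized_moe_cls(lambda: None) is object
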
View File

@@ -62,10 +62,13 @@ def load_weights_from_cache(model, weights_iterator):
logger.info(f"{loaded_weight_name} is not in model parameters.")
continue
param = params_dict[loaded_weight_name]
if param.shape != loaded_weight.shape:
raise ValueError(
f"Shape mismatch between loaded weight {loaded_weight_name}: {loaded_weight.shape}, expected shape: {param.shape}"
)
param.copy_(loaded_weight, False)
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
model.lm_head.linear.weight.set_value(loaded_weight)
model.lm_head.process_weights_after_loading()
model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
for _, model_sublayer in model.named_sublayers():
if isinstance(model_sublayer, KVBatchLinear):
model_sublayer.process_weights_after_loading()
@@ -107,7 +110,6 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
weight_cache_context = multi_switch_config_context(
(fd_config.quant_config, "is_checkpoint_bf16", False),
(fd_config.model_config, "model_format", "paddle"),
)
return enable_cache, weight_cache_dir, weight_cache_context

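load_weights_from_cache now validates the cached tensor's shape before copying it into the destination parameter, and for tied embeddings it transposes the table into the lm_head layout rather than re-running the lm_head's (now removed) post-loading rebuild. The shape check, restated as a small numpy stand-in:

import numpy as np


def copy_checked(param, loaded_weight, name):
    # Fail loudly on a stale or mis-shaped cache entry instead of silently
    # corrupting the destination parameter.
    if param.shape != loaded_weight.shape:
        raise ValueError(
            f"Shape mismatch between loaded weight {name}: {loaded_weight.shape}, "
            f"expected shape: {param.shape}"
        )
    np.copyto(param, loaded_weight)


param = np.zeros((8, 32), dtype=np.float32)
copy_checked(param, np.ones((8, 32), dtype=np.float32), "lm_head.linear.weight")
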
View File

@@ -56,8 +56,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
load_weights_from_cache(model, weights_iterator)
else:
model.load_weights(weights_iterator)
if fd_config.speculative_config.model_type != "mtp":
process_final_after_loading(model, fd_config)
if fd_config.speculative_config.model_type != "mtp":
process_final_after_loading(model, fd_config)
self.clean_memory_fragments()
@@ -76,6 +76,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
architectures = architectures + "RL"
enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
fd_config.model_config.enable_cache = enable_cache
with weight_cache_context:
with context:
model_cls = ModelRegistry.get_class(architectures)
@@ -88,6 +89,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
assert_never(convert_type)
model = model_cls(fd_config)
if fd_config.load_config.dynamic_load_weight or fd_config.model_config.enable_cache:
process_final_after_loading(model, fd_config)
model.eval()
# RL models do not need set_state_dict

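The loader changes adjust where process_final_after_loading runs relative to the MTP check (the indentation-only difference is invisible in this flattened diff) and additionally run it when dynamic weight loading or the weight cache is enabled, since those paths do not go through the normal set_state_dict flow. A simplified restatement of the weight-loading step, using the real function names but a hypothetical wrapper and none of the surrounding scaffolding:

def load_weights_sketch(model, fd_config, weights_iterator,
                        load_weights_from_cache, process_final_after_loading):
    if fd_config.model_config.enable_cache:
        load_weights_from_cache(model, weights_iterator)
    else:
        model.load_weights(weights_iterator)
    # MTP draft models skip the final processing pass in either branch.
    if fd_config.speculative_config.model_type != "mtp":
        process_final_after_loading(model, fd_config)
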
View File

@@ -600,7 +600,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)

View File

@@ -718,7 +718,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):

View File

@@ -377,7 +377,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight)
self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]))
@classmethod
def name(self):

View File

@@ -231,7 +231,7 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):

View File

@@ -320,7 +320,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings and not is_pooling_model:
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight)
self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
@paddle.no_grad()
def set_state_dict(self, state_dict):

View File

@@ -131,16 +131,24 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
def process_weight_transpose(layer, weight_name):
weight = getattr(layer, weight_name)
if len(weight.shape) == 2:
weight_transpose = weight.transpose([1, 0])
weight_shape = weight.shape[::-1]
elif len(weight.shape) == 3:
weight_transpose = weight.transpose([0, 2, 1])
weight_shape = [weight.shape[0]] + list(weight.shape[1:][::-1])
weight_tmp = layer.create_parameter(
shape=weight_transpose.shape,
dtype=weight_transpose.dtype,
shape=weight_shape,
dtype=weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
is_bias=False,
)
if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
free_tensor(weight)
setattr(layer, weight_name, weight_tmp)
return
if len(weight.shape) == 2:
weight_transpose = weight.transpose([1, 0])
elif len(weight.shape) == 3:
weight_transpose = weight.transpose([0, 2, 1])
weight_tmp.copy_(weight_transpose, False)
free_tensor(weight)
setattr(layer, weight_name, weight_tmp)
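
process_weight_transpose now allocates the destination parameter from the transposed shape up front; when dynamic weight loading or the weight cache is enabled it stops there, leaving an empty parameter of the correct (transposed) shape for the loader to fill, and only otherwise copies the transposed data in. The control flow, restated as a numpy stand-in (the function name here is hypothetical):

import numpy as np


def process_weight_transpose_sketch(weight, defer_copy):
    # Allocate the transposed-shape buffer first.
    if weight.ndim == 2:
        new_shape, perm = weight.shape[::-1], (1, 0)
    elif weight.ndim == 3:
        new_shape, perm = (weight.shape[0],) + weight.shape[1:][::-1], (0, 2, 1)
    else:
        raise ValueError(f"unsupported weight rank: {weight.ndim}")
    out = np.zeros(new_shape, dtype=weight.dtype)

    if defer_copy:
        # dynamic_load_weight / enable_cache: the real data arrives later,
        # only the transposed shape matters now.
        return out

    np.copyto(out, weight.transpose(perm))
    return out


w = np.arange(6, dtype=np.float32).reshape(2, 3)
assert process_weight_transpose_sketch(w, defer_copy=False).shape == (3, 2)
assert process_weight_transpose_sketch(w, defer_copy=True).sum() == 0.0
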
@@ -163,9 +171,16 @@ def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
model_sublayer = sublayers_dict[model_sublayer_name]
if isinstance(model_sublayer, KVBatchLinear):
model_sublayer.process_weights_after_loading()
if fd_config.quant_config and not fd_config.quant_config.is_checkpoint_bf16:
# skip for offline quantization
return
if hasattr(model_sublayer, "quant_method"):
quant_method = getattr(model_sublayer, "quant_method", None)
unquant_moe_cls = type(get_moe_method())
unquant_moe_layer = get_moe_method()
if unquant_moe_layer is None:
unquant_moe_cls = object
else:
unquant_moe_cls = type(unquant_moe_layer)
if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
# skip unquantized linear
return
@@ -225,18 +240,23 @@ def process_final_after_loading(model, fd_config: FDConfig):
from fastdeploy.model_executor.layers.moe.moe import get_moe_method
for name, sublayer in model.named_sublayers():
quant_method = getattr(sublayer, "quant_method", None)
if quant_method is not None:
unquant_moe_cls = type(get_moe_method())
if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
continue
if hasattr(quant_method, "process_weights_after_loading"):
quant_method.process_weights_after_loading(sublayer)
if isinstance(sublayer, KVBatchLinear):
continue
quant_method = getattr(sublayer, "quant_method", None)
if quant_method is not None:
unquant_moe_layer = get_moe_method()
if unquant_moe_layer is None:
unquant_moe_cls = object
else:
unquant_moe_cls = type(unquant_moe_layer)
is_unquant_cls = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
is_offline_quantized_ckpt = not (fd_config.quant_config and fd_config.quant_config.is_checkpoint_bf16)
if is_unquant_cls or is_offline_quantized_ckpt:
if hasattr(quant_method, "process_weights_after_loading"):
quant_method.process_weights_after_loading(sublayer)
continue
if not hasattr(sublayer, "process_weights_after_loading"):
continue
# Only for specific layers, such as lmhead
sublayer.process_weights_after_loading()
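
The reworked loop skips KVBatchLinear sublayers, tolerates get_moe_method() returning None, and runs the quant-method hook both for unquantized layers and for offline-quantized checkpoints, otherwise falling through to the sublayer's own hook (e.g. the lm_head). The combined predicate, distilled into a hypothetical helper:

def should_run_quant_hook(quant_method, unquant_classes, quant_config):
    # unquant_classes: e.g. (UnquantizedLinearMethod, <unquantized MoE method class>)
    is_unquant_cls = type(quant_method) in unquant_classes
    is_offline_quantized_ckpt = not (quant_config and quant_config.is_checkpoint_bf16)
    return is_unquant_cls or is_offline_quantized_ckpt


class _FakeQuantConfig:
    is_checkpoint_bf16 = True      # bf16 checkpoint, quantized online


assert should_run_quant_hook(object(), (object,), quant_config=None)
assert not should_run_quant_hook(42, (object,), quant_config=_FakeQuantConfig())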