[RL]Resolve shape mismatch problems in RL-related modules (#5032)
* RL fix
* update
@@ -959,7 +959,7 @@ class KVBatchLinear(nn.Layer):
# Split num_attention_heads when using TP inference.
self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
self.local_rank = fd_config.parallel_config.tensor_parallel_rank

self.fd_config = fd_config
self.kv_b_proj = kv_b_proj

self.weight_dtype = self._helper.get_default_dtype()

@@ -968,7 +968,8 @@ class KVBatchLinear(nn.Layer):
self.weight_key = f"{prefix}.weight" # e.g., "kv_b_proj.weight"

def process_weights_after_loading(self):

    if self.fd_config.load_config.dynamic_load_weight:
        return
    w = self.kv_b_proj.weight.reshape(
        [
            self.kv_lora_rank,

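Note: the early return in this hunk skips the kv_b_proj split when weights are loaded dynamically for RL, presumably because the fused weight is not materialized at that point. Below is a small numpy sketch (not from the repository) of the kind of reshape/split the lines that follow perform, assuming the usual MLA layout where kv_b_proj maps kv_lora_rank to num_heads_per_partition * (qk_nope_head_dim + v_head_dim); all sizes are illustrative.

# Illustrative sketch only -- not part of the diff. Assumed MLA-style layout.
import numpy as np

kv_lora_rank = 512            # assumed example sizes
num_heads_per_partition = 8   # heads owned by this TP rank
qk_nope_head_dim = 128
v_head_dim = 128

# Fused kv_b_proj weight: [kv_lora_rank, heads * (qk_nope_head_dim + v_head_dim)]
w = np.random.rand(
    kv_lora_rank, num_heads_per_partition * (qk_nope_head_dim + v_head_dim)
).astype(np.float32)

# Make the per-head K/V blocks explicit, then split them apart.
w = w.reshape(kv_lora_rank, num_heads_per_partition, qk_nope_head_dim + v_head_dim)
w_uk, w_uv = np.split(w, [qk_nope_head_dim], axis=-1)

assert w_uk.shape == (kv_lora_rank, num_heads_per_partition, qk_nope_head_dim)
assert w_uv.shape == (kv_lora_rank, num_heads_per_partition, v_head_dim)
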
@@ -26,12 +26,7 @@ from fastdeploy.model_executor.layers.utils import (
    DEFAULT_VOCAB_PADDING_SIZE,
    pad_vocab_size,
)
from fastdeploy.model_executor.utils import (
    default_weight_loader,
    free_tensor,
    set_weight_attrs,
    temporary_dtype,
)
from fastdeploy.model_executor.utils import set_weight_attrs, temporary_dtype

from .utils import get_tensor

@@ -80,7 +75,6 @@ class ParallelLMHead(nn.Layer):
if num_embeddings % self.nranks != 0:
    num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
self.num_embeddings = num_embeddings
self.model_format = fd_config.model_config.model_format

ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

@@ -90,39 +84,21 @@ class ParallelLMHead(nn.Layer):
self.need_gather = True

with temporary_dtype(self.dtype):
    if self.fd_config.load_config.load_choices == "default_v1" and (
        self.model_format == "torch" or self.tie_word_embeddings
    ):
        self.linear = RowParallelLinear(
            num_embeddings,
            embedding_dim,
            mp_group=self.tp_group,
            weight_attr=None,
            has_bias=True if self.bias_key is not None else False,
            input_is_parallel=False,
            fuse_matmul_bias=False,
        )
        set_weight_attrs(
            self.linear.weight,
            {
                "weight_loader": default_weight_loader(self.fd_config),
            },
        )
        set_weight_attrs(self.linear.weight, {"output_dim": False})
    elif self.column_cut:
    if self.column_cut:
        need_gather = True
        self.linear = ColumnParallelLinear(
            embedding_dim,
            num_embeddings,
            mp_group=self.tp_group,
            weight_attr=None,
            has_bias=True if self.bias_key is not None else False,
            gather_output=self.need_gather,
            gather_output=need_gather,
            fuse_matmul_bias=False,
        )
        set_weight_attrs(
            self.linear.weight,
            {
                "weight_loader": default_weight_loader(self.fd_config),
                "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
            },
        )
        set_weight_attrs(self.linear.weight, {"output_dim": True})

@@ -139,34 +115,11 @@ class ParallelLMHead(nn.Layer):
set_weight_attrs(
    self.linear.weight,
    {
        "weight_loader": default_weight_loader(self.fd_config),
        "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
    },
)
set_weight_attrs(self.linear.weight, {"output_dim": False})

def process_weights_after_loading(self):
    if not (
        self.fd_config.load_config.load_choices == "default_v1"
        and (self.model_format == "torch" or self.tie_word_embeddings)
    ):
        return
    if not self.linear.weight._is_initialized():
        self.linear.weight.initialize()
    weight_transpose = self.linear.weight.transpose([1, 0])
    with temporary_dtype(self.dtype):
        linear = fleet.meta_parallel.ColumnParallelLinear(
            self.embedding_dim,
            self.num_embeddings,
            mp_group=self.tp_group,
            weight_attr=None,
            has_bias=True if self.bias_key is not None else False,
            gather_output=self.need_gather,
            fuse_matmul_bias=False,
        )
    linear.weight.set_value(weight_transpose)
    free_tensor(self.linear.weight)
    self.linear = linear

def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
    """
    Load the checkpoint state dictionary into the layer.

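Note: this is the core of the lm_head change. A ColumnParallelLinear built as ColumnParallelLinear(embedding_dim, num_embeddings, ...) stores its weight as [embedding_dim, num_embeddings], while the token-embedding table is [num_embeddings, embedding_dim], so tying the two requires a transpose; the PR appears to apply that transpose directly where the tied weight is assigned (see the set_value(...transpose([1, 0])) lines in the model hunks below) instead of rebuilding the layer in process_weights_after_loading. A numpy sketch of the shape relationship, with made-up sizes:

# Illustrative sketch only -- not part of the diff.
import numpy as np

num_embeddings, embedding_dim = 32000, 4096   # assumed example sizes

# nn.Embedding-style table: [num_embeddings, embedding_dim]
embed_tokens_weight = np.random.rand(num_embeddings, embedding_dim).astype(np.float32)

# ColumnParallelLinear(embedding_dim, num_embeddings) weight: [embedding_dim, num_embeddings]
lm_head_weight = embed_tokens_weight.T

hidden_states = np.random.rand(2, embedding_dim).astype(np.float32)
logits = hidden_states @ lm_head_weight       # x @ W -> [batch, num_embeddings]
assert logits.shape == (2, num_embeddings)
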
@@ -1422,13 +1422,6 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
down_proj_weight_name = self.added_weight_attrs[1]
up_gate_proj_scale_name = self.added_scale_attrs[0]
down_proj_scale_name = self.added_scale_attrs[1]
if (
    not weight_fully_copied(getattr(layer, up_gate_proj_weight_name))
    or not weight_fully_copied(getattr(layer, down_proj_weight_name))
    or not weight_fully_copied(getattr(layer, up_gate_proj_scale_name))
    or not weight_fully_copied(getattr(layer, down_proj_scale_name))
):
    return
process_weight_transpose(layer, up_gate_proj_weight_name)
process_weight_transpose(layer, down_proj_weight_name)
process_weight_transpose(layer, up_gate_proj_scale_name)

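Note: for block-wise FP8, each weight carries a per-block scale, so the weight and its scale have to be transposed together or the block grids stop lining up. A numpy sketch of that invariant, assuming square quantization blocks (128 is a common choice; the exact block size is not visible in this hunk):

# Illustrative sketch only -- not part of the diff.
import numpy as np

K, N, block = 512, 768, 128   # assumed example sizes

weight = np.random.rand(K, N).astype(np.float32)
scale = np.random.rand(K // block, N // block).astype(np.float32)  # one scale per (block x block) tile

weight_t = weight.T   # [N, K]
scale_t = scale.T     # [N // block, K // block]; tile (i, j) still matches scale (i, j)

assert weight_t.shape == (N, K)
assert scale_t.shape == (N // block, K // block)
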
@@ -63,7 +63,7 @@ def get_moe_method():
)

return MetaxCutlassUnquantizedFusedMoEMethod(None)
raise NotImplementedError
return None


def get_moe_scores(

@@ -189,7 +189,9 @@ class FusedMoE(nn.Layer):
    self.quant_method = moe_quant_config.get_quant_method(self)
    self.moe_quant_type = moe_quant_config.name()
else:
    # unquantized quant_method
    self.quant_method = get_moe_method()
assert self.quant_method is not None, "self.quant_method should not be None"
self.redundant_table_manger = redundant_table_manger
if self.ep_size > 1:
    self.quant_method.init_ep(self)

@@ -62,10 +62,13 @@ def load_weights_from_cache(model, weights_iterator):
logger.info(f"{loaded_weight_name} is not in model parameters.")
continue
param = params_dict[loaded_weight_name]
if param.shape != loaded_weight.shape:
    raise ValueError(
        f"Shape mismatch between loaded weight {loaded_weight_name}: {loaded_weight.shape}, expected shape: {param.shape}"
    )
param.copy_(loaded_weight, False)
if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
    model.lm_head.linear.weight.set_value(loaded_weight)
    model.lm_head.process_weights_after_loading()
    model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
for _, model_sublayer in model.named_sublayers():
    if isinstance(model_sublayer, KVBatchLinear):
        model_sublayer.process_weights_after_loading()

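Note: the guard in this hunk turns a silent or cryptic copy failure into an explicit error when a cached weight's shape does not match the parameter. A standalone sketch of the same check with numpy arrays standing in for paddle tensors; check_and_copy is a hypothetical helper, not FastDeploy API:

# Illustrative sketch only -- not part of the diff. `check_and_copy` is hypothetical.
import numpy as np

def check_and_copy(param: np.ndarray, loaded_weight: np.ndarray, name: str) -> None:
    if param.shape != loaded_weight.shape:
        raise ValueError(
            f"Shape mismatch between loaded weight {name}: {loaded_weight.shape}, "
            f"expected shape: {param.shape}"
        )
    np.copyto(param, loaded_weight)   # stands in for param.copy_(loaded_weight, False)

param = np.zeros((4096, 32000), dtype=np.float32)
check_and_copy(param, np.ones((4096, 32000), dtype=np.float32), "lm_head.linear.weight")      # ok
# check_and_copy(param, np.ones((32000, 4096), dtype=np.float32), "lm_head.linear.weight")    # raises ValueError
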
@@ -107,7 +110,6 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):

weight_cache_context = multi_switch_config_context(
    (fd_config.quant_config, "is_checkpoint_bf16", False),
    (fd_config.model_config, "model_format", "paddle"),
)

return enable_cache, weight_cache_dir, weight_cache_context

@@ -56,8 +56,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
    load_weights_from_cache(model, weights_iterator)
else:
    model.load_weights(weights_iterator)
    if fd_config.speculative_config.model_type != "mtp":
        process_final_after_loading(model, fd_config)
if fd_config.speculative_config.model_type != "mtp":
    process_final_after_loading(model, fd_config)

self.clean_memory_fragments()

@@ -76,6 +76,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
architectures = architectures + "RL"

enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
fd_config.model_config.enable_cache = enable_cache
with weight_cache_context:
    with context:
        model_cls = ModelRegistry.get_class(architectures)

@@ -88,6 +89,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
assert_never(convert_type)

model = model_cls(fd_config)
if fd_config.load_config.dynamic_load_weight or fd_config.model_config.enable_cache:
    process_final_after_loading(model, fd_config)

model.eval()
# RL model not need set_state_dict

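Note: together with the previous hunk, the loader records whether the weight cache is active and runs process_final_after_loading right after constructing the model when weights are loaded dynamically for RL or served from the cache, rather than only after the regular load_weights path. A tiny sketch of just that gating condition:

# Illustrative sketch only -- not part of the diff.
def should_finalize_at_construction(dynamic_load_weight: bool, enable_cache: bool) -> bool:
    # Mirrors the condition added in this hunk: finalize immediately when either
    # RL dynamic weight loading or the weight cache is in play.
    return dynamic_load_weight or enable_cache

assert should_finalize_at_construction(True, False)
assert should_finalize_at_construction(False, True)
assert not should_finalize_at_construction(False, False)
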
@@ -600,7 +600,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
    process_weights_after_loading_fn(model_sublayer_name, param)

if self.tie_word_embeddings:
    self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
    self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

def compute_logits(self, hidden_states: paddle.Tensor):
    logits = self.lm_head(hidden_states)

@@ -718,7 +718,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
    )
    process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
    self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
    self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):

@@ -377,7 +377,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
    model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
    process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
    self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight)
    self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]))

@classmethod
def name(self):

@@ -231,7 +231,7 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
    process_weights_after_loading_fn(model_sublayer_name, param)

if self.tie_word_embeddings:
    self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
    self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))

@paddle.no_grad()
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):

@@ -320,7 +320,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
    process_weights_after_loading_fn(model_sublayer_name, param)

if self.tie_word_embeddings and not is_pooling_model:
    self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight)
    self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))

@paddle.no_grad()
def set_state_dict(self, state_dict):

@@ -131,16 +131,24 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
def process_weight_transpose(layer, weight_name):
    weight = getattr(layer, weight_name)
    if len(weight.shape) == 2:
        weight_transpose = weight.transpose([1, 0])
        weight_shape = weight.shape[::-1]
    elif len(weight.shape) == 3:
        weight_transpose = weight.transpose([0, 2, 1])

        weight_shape = [weight.shape[0]] + list(weight.shape[1:][::-1])
    weight_tmp = layer.create_parameter(
        shape=weight_transpose.shape,
        dtype=weight_transpose.dtype,
        shape=weight_shape,
        dtype=weight.dtype,
        default_initializer=paddle.nn.initializer.Constant(0),
        is_bias=False,
    )
    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
        free_tensor(weight)
        setattr(layer, weight_name, weight_tmp)
        return

    if len(weight.shape) == 2:
        weight_transpose = weight.transpose([1, 0])
    elif len(weight.shape) == 3:
        weight_transpose = weight.transpose([0, 2, 1])
    weight_tmp.copy_(weight_transpose, False)
    free_tensor(weight)
    setattr(layer, weight_name, weight_tmp)

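Note: the rewritten process_weight_transpose derives the destination shape from the source weight (reverse both dims for 2-D, reverse only the trailing dims for 3-D, e.g. stacked expert weights) and allocates a zero-filled parameter first, which is what lets the dynamic-load / cache path return early with a correctly shaped but empty parameter. A numpy sketch of just the shape logic:

# Illustrative sketch only -- not part of the diff.
import numpy as np

def transposed_shape(shape):
    if len(shape) == 2:
        return list(shape[::-1])                      # [m, n] -> [n, m]
    if len(shape) == 3:
        return [shape[0]] + list(shape[1:][::-1])     # [e, m, n] -> [e, n, m]
    raise ValueError(f"unsupported weight rank: {len(shape)}")

w2 = np.zeros((16, 32), dtype=np.float32)
w3 = np.zeros((4, 16, 32), dtype=np.float32)

assert tuple(transposed_shape(w2.shape)) == w2.transpose(1, 0).shape
assert tuple(transposed_shape(w3.shape)) == w3.transpose(0, 2, 1).shape
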
@@ -163,9 +171,16 @@ def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
model_sublayer = sublayers_dict[model_sublayer_name]
if isinstance(model_sublayer, KVBatchLinear):
    model_sublayer.process_weights_after_loading()
if fd_config.quant_config and not fd_config.quant_config.is_checkpoint_bf16:
    # skip for offline quantization
    return
if hasattr(model_sublayer, "quant_method"):
    quant_method = getattr(model_sublayer, "quant_method", None)
    unquant_moe_cls = type(get_moe_method())
    unquant_moe_layer = get_moe_method()
    if unquant_moe_layer is None:
        unquant_moe_cls = object
    else:
        unquant_moe_cls = type(unquant_moe_layer)
    if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
        # skip unquantized linear
        return

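Note: this pairs with the get_moe_method change above, which returns None instead of raising on platforms without an unquantized MoE backend; falling back to object means the type(quant_method) is unquant_moe_cls test simply never matches in that case. A compact sketch of the pattern:

# Illustrative sketch only -- not part of the diff.
from typing import Optional

class UnquantizedLinearMethod:        # stand-in for the real class
    pass

def get_moe_method() -> Optional[object]:
    return None                       # e.g., platform without an unquantized MoE backend

unquant_moe_layer = get_moe_method()
unquant_moe_cls = object if unquant_moe_layer is None else type(unquant_moe_layer)

quant_method = UnquantizedLinearMethod()
is_unquant = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
assert is_unquant
# No real quant_method instance has type(...) is object, so the MoE leg of the
# check is effectively disabled when get_moe_method() returns None.
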
@@ -225,18 +240,23 @@ def process_final_after_loading(model, fd_config: FDConfig):
from fastdeploy.model_executor.layers.moe.moe import get_moe_method

for name, sublayer in model.named_sublayers():
    quant_method = getattr(sublayer, "quant_method", None)
    if quant_method is not None:
        unquant_moe_cls = type(get_moe_method())
        if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
            continue
        if hasattr(quant_method, "process_weights_after_loading"):
            quant_method.process_weights_after_loading(sublayer)
    if isinstance(sublayer, KVBatchLinear):
        continue
    quant_method = getattr(sublayer, "quant_method", None)
    if quant_method is not None:
        unquant_moe_layer = get_moe_method()
        if unquant_moe_layer is None:
            unquant_moe_cls = object
        else:
            unquant_moe_cls = type(unquant_moe_layer)
        is_unquant_cls = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
        is_offline_quantized_ckpt = not (fd_config.quant_config and fd_config.quant_config.is_checkpoint_bf16)
        if is_unquant_cls or is_offline_quantized_ckpt:
            if hasattr(quant_method, "process_weights_after_loading"):
                quant_method.process_weights_after_loading(sublayer)
        continue
    if not hasattr(sublayer, "process_weights_after_loading"):
        continue
    # Only for specific layers, such as lmhead
    sublayer.process_weights_after_loading()
