[RL]Resolve shape mismatch problems in RL-related modules (#5032)

* RL fix

* update
This commit is contained in:
bukejiyu
2025-11-19 11:12:48 +08:00
committed by GitHub
parent 4694ed2a43
commit a82f25ea7b
12 changed files with 61 additions and 87 deletions

View File

@@ -131,16 +131,24 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
def process_weight_transpose(layer, weight_name):
weight = getattr(layer, weight_name)
if len(weight.shape) == 2:
weight_transpose = weight.transpose([1, 0])
weight_shape = weight.shape[::-1]
elif len(weight.shape) == 3:
weight_transpose = weight.transpose([0, 2, 1])
weight_shape = [weight.shape[0]] + list(weight.shape[1:][::-1])
weight_tmp = layer.create_parameter(
shape=weight_transpose.shape,
dtype=weight_transpose.dtype,
shape=weight_shape,
dtype=weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
is_bias=False,
)
if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
free_tensor(weight)
setattr(layer, weight_name, weight_tmp)
return
if len(weight.shape) == 2:
weight_transpose = weight.transpose([1, 0])
elif len(weight.shape) == 3:
weight_transpose = weight.transpose([0, 2, 1])
weight_tmp.copy_(weight_transpose, False)
free_tensor(weight)
setattr(layer, weight_name, weight_tmp)
@@ -163,9 +171,16 @@ def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
model_sublayer = sublayers_dict[model_sublayer_name]
if isinstance(model_sublayer, KVBatchLinear):
model_sublayer.process_weights_after_loading()
if fd_config.quant_config and not fd_config.quant_config.is_checkpoint_bf16:
# skip for offline quantization
return
if hasattr(model_sublayer, "quant_method"):
quant_method = getattr(model_sublayer, "quant_method", None)
unquant_moe_cls = type(get_moe_method())
unquant_moe_layer = get_moe_method()
if unquant_moe_layer is None:
unquant_moe_cls = object
else:
unquant_moe_cls = type(unquant_moe_layer)
if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
# skip unquantized linear
return
@@ -225,18 +240,23 @@ def process_final_after_loading(model, fd_config: FDConfig):
from fastdeploy.model_executor.layers.moe.moe import get_moe_method
for name, sublayer in model.named_sublayers():
quant_method = getattr(sublayer, "quant_method", None)
if quant_method is not None:
unquant_moe_cls = type(get_moe_method())
if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
continue
if hasattr(quant_method, "process_weights_after_loading"):
quant_method.process_weights_after_loading(sublayer)
if isinstance(sublayer, KVBatchLinear):
continue
quant_method = getattr(sublayer, "quant_method", None)
if quant_method is not None:
unquant_moe_layer = get_moe_method()
if unquant_moe_layer is None:
unquant_moe_cls = object
else:
unquant_moe_cls = type(unquant_moe_layer)
is_unquant_cls = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
is_offline_quantized_ckpt = not (fd_config.quant_config and fd_config.quant_config.is_checkpoint_bf16)
if is_unquant_cls or is_offline_quantized_ckpt:
if hasattr(quant_method, "process_weights_after_loading"):
quant_method.process_weights_after_loading(sublayer)
continue
if not hasattr(sublayer, "process_weights_after_loading"):
continue
# Only for specific layers, such as lmhead
sublayer.process_weights_after_loading()