Mirror of
https://github.com/PaddlePaddle/FastDeploy.git
Last synced: 2025-12-24 13:28:13 +08:00
[RL]Resolve shape mismatch problems in RL-related modules (#5032)
* RL fix * update
This commit is contained in:
@@ -131,16 +131,24 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
|
||||
def process_weight_transpose(layer, weight_name):
    """Replace ``layer.<weight_name>`` with a transposed copy of itself.

    A fresh parameter is created whose shape is the transpose of the original
    (last two dims swapped; for 3-D weights the leading "expert"/batch dim is
    kept in place), the transposed data is copied in, the original tensor is
    freed, and the new parameter is bound back onto ``layer`` under the same
    attribute name.

    Args:
        layer: Layer owning the weight; must provide ``create_parameter`` and
            a ``fd_config`` (presumably an FDConfig — confirm against caller).
        weight_name: Attribute name of the 2-D or 3-D weight to transpose.

    NOTE(review): reconstructed from a garbled diff view — this is the
    post-patch (new) side of the hunk; the pre-patch lines that created the
    parameter from ``weight_transpose.shape`` are dropped.
    """
    weight = getattr(layer, weight_name)
    # Compute the target (transposed) shape WITHOUT materializing the
    # transpose yet, so the dynamic-load fast path below allocates the
    # correctly shaped placeholder without touching the data.
    if len(weight.shape) == 2:
        weight_shape = weight.shape[::-1]
    elif len(weight.shape) == 3:
        # Keep dim 0 (e.g. num_experts), swap the trailing two dims.
        weight_shape = [weight.shape[0]] + list(weight.shape[1:][::-1])

    weight_tmp = layer.create_parameter(
        shape=weight_shape,
        dtype=weight.dtype,
        default_initializer=paddle.nn.initializer.Constant(0),
        is_bias=False,
    )

    # Dynamic weight loading / cache-enabled mode: the real data arrives
    # later, so only swap in the (zero-initialized) placeholder and return.
    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
        free_tensor(weight)
        setattr(layer, weight_name, weight_tmp)
        return

    # Eager path: materialize the transpose and copy it into the new
    # parameter (copy_(..., False) — presumably a blocking=False copy;
    # confirm against Paddle's Tensor.copy_ signature).
    if len(weight.shape) == 2:
        weight_transpose = weight.transpose([1, 0])
    elif len(weight.shape) == 3:
        weight_transpose = weight.transpose([0, 2, 1])
    weight_tmp.copy_(weight_transpose, False)

    free_tensor(weight)
    setattr(layer, weight_name, weight_tmp)
|
||||
@@ -163,9 +171,16 @@ def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
|
||||
model_sublayer = sublayers_dict[model_sublayer_name]
|
||||
if isinstance(model_sublayer, KVBatchLinear):
|
||||
model_sublayer.process_weights_after_loading()
|
||||
if fd_config.quant_config and not fd_config.quant_config.is_checkpoint_bf16:
|
||||
# skip for offline quantization
|
||||
return
|
||||
if hasattr(model_sublayer, "quant_method"):
|
||||
quant_method = getattr(model_sublayer, "quant_method", None)
|
||||
unquant_moe_cls = type(get_moe_method())
|
||||
unquant_moe_layer = get_moe_method()
|
||||
if unquant_moe_layer is None:
|
||||
unquant_moe_cls = object
|
||||
else:
|
||||
unquant_moe_cls = type(unquant_moe_layer)
|
||||
if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
|
||||
# skip unquantized linear
|
||||
return
|
||||
@@ -225,18 +240,23 @@ def process_final_after_loading(model, fd_config: FDConfig):
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_method
|
||||
|
||||
for name, sublayer in model.named_sublayers():
|
||||
quant_method = getattr(sublayer, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
unquant_moe_cls = type(get_moe_method())
|
||||
if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
|
||||
continue
|
||||
if hasattr(quant_method, "process_weights_after_loading"):
|
||||
quant_method.process_weights_after_loading(sublayer)
|
||||
if isinstance(sublayer, KVBatchLinear):
|
||||
continue
|
||||
quant_method = getattr(sublayer, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
unquant_moe_layer = get_moe_method()
|
||||
if unquant_moe_layer is None:
|
||||
unquant_moe_cls = object
|
||||
else:
|
||||
unquant_moe_cls = type(unquant_moe_layer)
|
||||
is_unquant_cls = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
|
||||
is_offline_quantized_ckpt = not (fd_config.quant_config and fd_config.quant_config.is_checkpoint_bf16)
|
||||
if is_unquant_cls or is_offline_quantized_ckpt:
|
||||
if hasattr(quant_method, "process_weights_after_loading"):
|
||||
quant_method.process_weights_after_loading(sublayer)
|
||||
continue
|
||||
if not hasattr(sublayer, "process_weights_after_loading"):
|
||||
continue
|
||||
# Only for specific layers, such as lmhead
|
||||
sublayer.process_weights_after_loading()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user