refactor pt loading (#4532)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
bukejiyu
2025-11-11 21:30:39 +08:00
committed by GitHub
parent 4c911ecb74
commit b09ebb2813
35 changed files with 1094 additions and 797 deletions

View File

@@ -128,13 +128,35 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
return weight_or_paramter
def process_weights_after_loading(sublayers_dict: dict):
def process_weight_transpose(layer, weight_name):
    """Replace ``layer.<weight_name>`` with a transposed copy of itself.

    2-D weights are transposed over both axes; 3-D weights (e.g. stacked /
    grouped expert weights — TODO confirm) over their last two axes. The
    transposed data is copied into a freshly created parameter, the old
    tensor is freed, and the new parameter is bound under the same name.

    Args:
        layer: paddle layer owning the weight; must support
            ``create_parameter``.
        weight_name: attribute name of the weight on ``layer``.

    Raises:
        ValueError: if the weight is neither 2-D nor 3-D (previously this
            fell through to an obscure NameError on ``weight_transpose``).
    """
    weight = getattr(layer, weight_name)
    if len(weight.shape) == 2:
        weight_transpose = weight.transpose([1, 0])
    elif len(weight.shape) == 3:
        weight_transpose = weight.transpose([0, 2, 1])
    else:
        raise ValueError(
            f"process_weight_transpose expects a 2-D or 3-D weight, "
            f"got shape {weight.shape} for {weight_name!r}"
        )
    weight_tmp = layer.create_parameter(
        shape=weight_transpose.shape,
        dtype=weight_transpose.dtype,
        default_initializer=paddle.nn.initializer.Constant(0),
        is_bias=False,
    )
    # Blocking=False: async copy; the old tensor is released right after.
    weight_tmp.copy_(weight_transpose, False)
    free_tensor(weight)
    setattr(layer, weight_name, weight_tmp)
def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
"""
process_weights_after_loading: e.g., handle extracted weights (quantization, reshaping, etc.)
process_weights_after_loading:
"""
def fn(model_sublayer_name: str, param=None):
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.layers.linear import (
KVBatchLinear,
UnquantizedLinearMethod,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_method
if model_sublayer_name not in sublayers_dict:
return
@@ -143,6 +165,10 @@ def process_weights_after_loading(sublayers_dict: dict):
model_sublayer.process_weights_after_loading()
if hasattr(model_sublayer, "quant_method"):
quant_method = getattr(model_sublayer, "quant_method", None)
unquant_moe_cls = type(get_moe_method())
if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
# skip unquantized linear
return
if not hasattr(quant_method, "process_weights_after_loading"):
return
if param is not None and hasattr(param, "tensor_track") and param.tensor_track is None:
@@ -184,6 +210,36 @@ def process_weights_before_loading(
return fn
def weight_fully_copied(weight):
    """Tell whether *weight*'s copy tracker reports a complete copy.

    Returns False when the tensor carries no ``tensor_track`` attribute (or
    it is None); otherwise returns whatever ``is_fully_copied()`` reports.
    """
    track = getattr(weight, "tensor_track", None)
    if track is None:
        return False
    return track.is_fully_copied()
def process_final_after_loading(model, fd_config: FDConfig):
    """Post-loading pass for cases other than dynamic quantization.

    Walks every sublayer of *model* and:
      * for layers whose ``quant_method`` is the UNquantized linear/MoE
        method and exposes ``process_weights_after_loading``, invokes it
        (layers with a real quantized method are skipped here entirely);
      * invokes the sublayer's own ``process_weights_after_loading`` when
        it defines one (e.g. lm_head), except for ``KVBatchLinear``.

    NOTE(review): ``fd_config`` is not referenced in this body — presumably
    kept for signature symmetry with the other loading hooks; confirm
    before removing.
    """
    from fastdeploy.model_executor.layers.linear import (
        KVBatchLinear,
        UnquantizedLinearMethod,
    )
    from fastdeploy.model_executor.layers.moe.moe import get_moe_method

    for name, sublayer in model.named_sublayers():
        quant_method = getattr(sublayer, "quant_method", None)
        if quant_method is not None:
            # The unquantized MoE method class is obtained dynamically via
            # get_moe_method(), so the comparison is by exact type.
            unquant_moe_cls = type(get_moe_method())
            if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
                # Quantized layer: handled by the quantization path, not here.
                continue
            if hasattr(quant_method, "process_weights_after_loading"):
                quant_method.process_weights_after_loading(sublayer)
        if isinstance(sublayer, KVBatchLinear):
            continue
        if not hasattr(sublayer, "process_weights_after_loading"):
            continue
        # Only for specific layers, such as lmhead
        sublayer.process_weights_after_loading()
def free_tensor(tensor):
if hasattr(tensor, "tensor_track"):
tensor.tensor_track = None
@@ -191,6 +247,15 @@ def free_tensor(tensor):
del tensor
def fd_cast(weight, param):
    """Bring *weight* to *param*'s dtype before copying into it.

    int8 -> float8_e4m3fn is a bitwise reinterpretation via ``view``
    (presumably fp8 checkpoints are serialized as raw int8 bytes — confirm
    against the loader); any other dtype mismatch is a value ``cast``.
    Weights already in the target dtype are returned untouched.
    """
    if weight.dtype == param.dtype:
        return weight
    if weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
        return weight.view(param.dtype)
    return weight.cast(param.dtype)
def default_weight_loader(fd_config: FDConfig = None) -> None:
"""Default weight loader"""
@@ -200,7 +265,6 @@ def default_weight_loader(fd_config: FDConfig = None) -> None:
output_dim = getattr(param, "output_dim", None)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = get_tensor(loaded_weight)
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
@@ -214,20 +278,15 @@ def default_weight_loader(fd_config: FDConfig = None) -> None:
shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
if param.dtype != loaded_weight.dtype:
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
loaded_weight = loaded_weight.view(param.dtype)
else:
loaded_weight = loaded_weight.cast(param.dtype)
loaded_weight = fd_cast(loaded_weight, param)
if param.shape != loaded_weight.shape:
# for e_score_correction_bias
loaded_weight = loaded_weight.reshape(param.shape)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
h2d_copy(dst=param, src=loaded_weight)
return fn
@@ -255,10 +314,44 @@ def is_paddle_support_v1_loader():
return is_same
# Cached result of the one-shot capability probe below (None = not probed yet).
_support_new_h2d = None


def is_paddle_support_new_h2d():
    """Probe whether paddle can ``copy_`` a CPU tensor into a sliced device view.

    The probe runs in a throwaway subprocess so that any failure inside
    paddle cannot poison the current process; the boolean outcome is cached
    in ``_support_new_h2d`` so the subprocess is spawned at most once.
    """
    import subprocess
    import sys

    global _support_new_h2d
    if _support_new_h2d is None:
        probe = """
import paddle
try:
    dst = paddle.zeros([2, 4], dtype='bfloat16')
    src = paddle.ones([2, 2], dtype='bfloat16', device='cpu')
    dst = dst[..., :2]
    dst.copy_(src)
    print(1)
except:
    print(0)
"""
        completed = subprocess.run([sys.executable, "-c", probe], capture_output=True)
        # Any crash (including ImportError) leaves stdout != b"1" -> False.
        _support_new_h2d = completed.stdout.strip() == b"1"
    return _support_new_h2d
def h2d_copy(dst, src, blocking=True):
    """Copy *src* into *dst*, initializing *dst* first when needed.

    On CUDA with a paddle new enough to copy a host tensor straight into a
    device view, *src* is handed to ``copy_`` as-is; otherwise it is moved
    to the device (H2D) up front via ``get_tensor``.
    """
    device_side_copy_ok = current_platform.is_cuda() and is_paddle_support_new_h2d()
    if not device_side_copy_ok:
        # For non-GPU devices, data is transferred to device (H2D) in advance.
        src = get_tensor(src)
    if not dst._is_initialized():
        dst.initialize()
    dst.copy_(src, blocking)
def v1_loader_support(fd_config):
_v1_no_support_archs = [
"Qwen2VLForConditionalGeneration",
]
_v1_no_support_archs = ["Qwen2VLForConditionalGeneration"]
def _err_msg(msg: str) -> str:
logger.info(msg + "; fallback to the v0 loader for model loading.")
@@ -310,14 +403,20 @@ def temporary_dtype(dtype: str):
@contextmanager
def switch_config_context(config_obj, config_attr_name, value):
"""switch_config_context"""
origin_value = getattr(config_obj, config_attr_name)
setattr(config_obj, config_attr_name, value)
@contextmanager
def multi_switch_config_context(*changes):
    """Temporarily apply several attribute changes at once.

    Each change is an ``(obj, attr, new_value)`` triple. On exit —
    normal or exceptional — every recorded attribute is restored to its
    original value, in the order it was recorded.

    As rendered in this diff the function body had lost its
    ``@contextmanager`` decorator and still contained a stale line from
    the removed single-attribute variant
    (``setattr(config_obj, config_attr_name, origin_value)``) referencing
    undefined names inside ``finally``; both are fixed here.

    Raises:
        AttributeError: if some ``attr`` does not exist on its ``obj``;
            changes applied before the failure are still rolled back.
    """
    originals = []
    try:
        for obj, attr, new_value in changes:
            # Record the old value before overwriting so rollback is possible.
            originals.append((obj, attr, getattr(obj, attr)))
            setattr(obj, attr, new_value)
        yield
    finally:
        for obj, attr, old_value in originals:
            setattr(obj, attr, old_value)
def rename_offline_ckpt_suffix_to_fd_suffix(