mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
refactor pt loading (#4532)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
This commit is contained in:
@@ -128,13 +128,35 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
|
||||
return weight_or_paramter
|
||||
|
||||
|
||||
def process_weights_after_loading(sublayers_dict: dict):
|
||||
def process_weight_transpose(layer, weight_name):
    """Replace ``layer.<weight_name>`` with a transposed copy of itself.

    A 2-D weight has its two axes swapped; a 3-D (e.g. expert-batched)
    weight has its last two axes swapped. A fresh parameter is created for
    the transposed data, the old tensor is freed via ``free_tensor``, and
    the layer attribute is rebound to the new parameter.

    Raises:
        ValueError: if the weight is neither 2-D nor 3-D.
    """
    weight = getattr(layer, weight_name)
    if len(weight.shape) == 2:
        weight_transpose = weight.transpose([1, 0])
    elif len(weight.shape) == 3:
        weight_transpose = weight.transpose([0, 2, 1])
    else:
        # Previously fell through with `weight_transpose` unbound, raising a
        # confusing NameError below; fail loudly with context instead.
        raise ValueError(
            f"process_weight_transpose only supports 2-D or 3-D weights, got shape {weight.shape}"
        )

    weight_tmp = layer.create_parameter(
        shape=weight_transpose.shape,
        dtype=weight_transpose.dtype,
        default_initializer=paddle.nn.initializer.Constant(0),
        is_bias=False,
    )
    # Non-blocking copy of the transposed data into the new parameter.
    weight_tmp.copy_(weight_transpose, False)
    free_tensor(weight)
    setattr(layer, weight_name, weight_tmp)
|
||||
|
||||
|
||||
def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
|
||||
"""
|
||||
process_weights_after_loading: e.g., handle extracted weights (quantization, reshaping, etc.)
|
||||
process_weights_after_loading:
|
||||
"""
|
||||
|
||||
def fn(model_sublayer_name: str, param=None):
|
||||
from fastdeploy.model_executor.layers.linear import KVBatchLinear
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
KVBatchLinear,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_method
|
||||
|
||||
if model_sublayer_name not in sublayers_dict:
|
||||
return
|
||||
@@ -143,6 +165,10 @@ def process_weights_after_loading(sublayers_dict: dict):
|
||||
model_sublayer.process_weights_after_loading()
|
||||
if hasattr(model_sublayer, "quant_method"):
|
||||
quant_method = getattr(model_sublayer, "quant_method", None)
|
||||
unquant_moe_cls = type(get_moe_method())
|
||||
if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
|
||||
# skip unquantized linear
|
||||
return
|
||||
if not hasattr(quant_method, "process_weights_after_loading"):
|
||||
return
|
||||
if param is not None and hasattr(param, "tensor_track") and param.tensor_track is None:
|
||||
@@ -184,6 +210,36 @@ def process_weights_before_loading(
|
||||
return fn
|
||||
|
||||
|
||||
def weight_fully_copied(weight):
    """Return truthy iff *weight* carries a tensor_track that reports being fully copied."""
    track = getattr(weight, "tensor_track", None)
    return track is not None and track.is_fully_copied()
|
||||
|
||||
|
||||
def process_final_after_loading(model, fd_config: FDConfig):
    # process_final_after_loading handles the post-loading process for cases other than dynamic quantization.
    from fastdeploy.model_executor.layers.linear import (
        KVBatchLinear,
        UnquantizedLinearMethod,
    )
    from fastdeploy.model_executor.layers.moe.moe import get_moe_method

    for name, sublayer in model.named_sublayers():
        quant_method = getattr(sublayer, "quant_method", None)
        if quant_method is not None:
            unquant_moe_cls = type(get_moe_method())
            # NOTE(review): only the unquantized linear / unquantized MoE
            # methods are finalized here; any other quant_method is skipped
            # entirely — presumably handled by a different (dynamic-quant)
            # path. Confirm against the callers.
            if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
                continue
            if hasattr(quant_method, "process_weights_after_loading"):
                quant_method.process_weights_after_loading(sublayer)
        # KVBatchLinear layers are excluded from the per-layer hook below.
        if isinstance(sublayer, KVBatchLinear):
            continue
        if not hasattr(sublayer, "process_weights_after_loading"):
            continue
        # Only for specific layers, such as lmhead
        sublayer.process_weights_after_loading()
|
||||
|
||||
|
||||
def free_tensor(tensor):
|
||||
if hasattr(tensor, "tensor_track"):
|
||||
tensor.tensor_track = None
|
||||
@@ -191,6 +247,15 @@ def free_tensor(tensor):
|
||||
del tensor
|
||||
|
||||
|
||||
def fd_cast(weight, param):
    """Return *weight* converted to *param*'s dtype.

    A no-op when the dtypes already match. An int8 weight paired with a
    float8_e4m3fn parameter is reinterpreted bit-for-bit via ``view``;
    every other mismatch goes through a value-converting ``cast``.
    """
    if weight.dtype == param.dtype:
        return weight
    if weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
        return weight.view(param.dtype)
    return weight.cast(param.dtype)
|
||||
|
||||
|
||||
def default_weight_loader(fd_config: FDConfig = None) -> None:
|
||||
"""Default weight loader"""
|
||||
|
||||
@@ -200,7 +265,6 @@ def default_weight_loader(fd_config: FDConfig = None) -> None:
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
weight_need_transpose = getattr(param, "weight_need_transpose", False)
|
||||
if weight_need_transpose:
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
|
||||
@@ -214,20 +278,15 @@ def default_weight_loader(fd_config: FDConfig = None) -> None:
|
||||
shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
|
||||
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
|
||||
if param.dtype != loaded_weight.dtype:
|
||||
if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
|
||||
loaded_weight = loaded_weight.view(param.dtype)
|
||||
else:
|
||||
loaded_weight = loaded_weight.cast(param.dtype)
|
||||
loaded_weight = fd_cast(loaded_weight, param)
|
||||
if param.shape != loaded_weight.shape:
|
||||
# for e_score_correction_bias
|
||||
loaded_weight = loaded_weight.reshape(param.shape)
|
||||
assert param.shape == loaded_weight.shape, (
|
||||
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
|
||||
)
|
||||
param.copy_(loaded_weight, False)
|
||||
h2d_copy(dst=param, src=loaded_weight)
|
||||
|
||||
return fn
|
||||
|
||||
@@ -255,10 +314,44 @@ def is_paddle_support_v1_loader():
|
||||
return is_same
|
||||
|
||||
|
||||
# Cached probe result; None means the capability check has not run yet.
_support_new_h2d = None


def is_paddle_support_new_h2d():
    """Return True if the installed paddle can ``copy_()`` a CPU tensor
    into a sliced (non-contiguous) destination tensor.

    The probe script runs in a fresh subprocess — presumably so a hard
    failure inside paddle cannot crash the current process. The outcome is
    memoized in the module-level ``_support_new_h2d``.
    """
    import subprocess
    import sys

    global _support_new_h2d
    if _support_new_h2d is not None:
        # Probe already ran in this process; reuse the cached answer.
        return _support_new_h2d

    # Minimal repro: slice the destination, then copy a CPU-resident source
    # into it. Prints 1 on success, 0 on any failure.
    code = """
import paddle
try:
    dst = paddle.zeros([2, 4], dtype='bfloat16')
    src = paddle.ones([2, 2], dtype='bfloat16', device='cpu')
    dst = dst[..., :2]
    dst.copy_(src)
    print(1)
except:
    print(0)
"""
    result = subprocess.run([sys.executable, "-c", code], capture_output=True)
    _support_new_h2d = result.stdout.strip() == b"1"
    return _support_new_h2d
|
||||
|
||||
|
||||
def h2d_copy(dst, src, blocking=True):
    """Copy ``src`` into the (possibly not-yet-initialized) tensor ``dst``.

    On non-CUDA platforms, or when paddle lacks the newer H2D copy path
    (see ``is_paddle_support_new_h2d``), ``src`` is first materialized via
    ``get_tensor`` before the copy.

    Args:
        dst: destination tensor/parameter; initialized lazily if needed.
        src: source tensor (may be host-resident).
        blocking: forwarded to ``dst.copy_``; True waits for completion.
    """
    if not current_platform.is_cuda() or not is_paddle_support_new_h2d():
        # For non-GPU devices, data is transferred to device (H2D) in advance.
        src = get_tensor(src)
    if not dst._is_initialized():
        # Lazily-created parameters must be initialized before copy_().
        dst.initialize()
    dst.copy_(src, blocking)
|
||||
|
||||
|
||||
def v1_loader_support(fd_config):
|
||||
_v1_no_support_archs = [
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
]
|
||||
_v1_no_support_archs = ["Qwen2VLForConditionalGeneration"]
|
||||
|
||||
def _err_msg(msg: str) -> str:
|
||||
logger.info(msg + "; fallback to the v0 loader for model loading.")
|
||||
@@ -310,14 +403,20 @@ def temporary_dtype(dtype: str):
|
||||
|
||||
|
||||
@contextmanager
def multi_switch_config_context(*changes):
    """Temporarily apply several attribute changes, restoring them on exit.

    Args:
        changes: any number of ``(obj, attr, new_value)`` triples. Each
            ``obj.attr`` is set to ``new_value`` on entry and restored to
            its prior value on exit — also on exception, and even when a
            later triple fails to apply (partial application unwinds).

    Yields:
        None, with all the changes in effect.
    """
    originals = []
    try:
        for obj, attr, new_value in changes:
            old_value = getattr(obj, attr)
            originals.append((obj, attr, old_value))
            setattr(obj, attr, new_value)
        yield
    finally:
        # Restore in reverse order so repeated changes to the same
        # (obj, attr) unwind correctly back to the first original value;
        # forward-order restoration would leave an intermediate value.
        for obj, attr, old_value in reversed(originals):
            setattr(obj, attr, old_value)
|
||||
|
||||
|
||||
def rename_offline_ckpt_suffix_to_fd_suffix(
|
||||
|
||||
Reference in New Issue
Block a user