Files
FastDeploy/fastdeploy/model_executor/utils.py
CSWYF3634076 acd331780c [V1 loader] Qwen25 VL support v1 loader and torch style safetensors load (#4388)
* [BugFix] qwen2.5vl enable_thinking=true and image_patch_id bug fix

* [Docs] offline infer: add apply_chat_template add_generation_prompt parameter

* [Model]qwen2.5VL support --use-cudagraph

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v2

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v3

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v4

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v5

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v6

* [Model]qwen2.5VL support --use-cudagraph buffer and qwenvl test v7

* qwen25vl v1 loader

* qwen25vl v1 loader v2

* qwen25vl v1 loader v3

* qwen25vl v1 loader fix tp2 weight PySafeSlice

* qwen25vl v1 loader no test

* qwen25vl v1 loader add unit test

* qwen25vl v1 loader add unit test v2

* qwen25vl v1 loader add torch unit test v3

* qwen25vl v1 loader add torch unit test v4

* qwen25vl v1 loader add torch unit test v5

* qwen25vl v1 loader add torch unit test v6
2025-10-27 10:54:15 +08:00

370 lines
13 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import re
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Union

import paddle
from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
class BitMaskTracker:
    """Record which positions along a single dimension have been filled.

    Bit ``i`` of ``mask`` is set once position ``i`` has been marked.
    """

    def __init__(self, length: int):
        """
        Create a tracker in which no position is filled yet.
        Args:
            length (int): Number of positions to track (e.g., columns or rows)
        """
        self.mask = 0
        self.length = length

    def mark(self, start: int, end: int):
        """
        Flag every position in the half-open range [start, end) as filled.
        Args:
            start (int): Start index (inclusive)
            end (int): End index (exclusive)
        Raises:
            ValueError: If the range is empty, reversed, or leaves [0, length].
        """
        range_is_valid = 0 <= start and end <= self.length and start < end
        if not range_is_valid:
            raise ValueError("Invalid mark range")
        width = end - start
        # `width` consecutive one-bits, shifted so the lowest one sits at `start`.
        self.mask |= ((1 << width) - 1) << start

    def is_full(self) -> bool:
        """Return True if all positions are filled."""
        all_positions = (1 << self.length) - 1
        return self.mask == all_positions
class TensorTracker:
    """Track how much of a 2D or 3D tensor has been filled along one dimension.

    A 2D tensor uses a single BitMaskTracker; a 3D tensor uses one
    BitMaskTracker per index of its first dimension.
    """

    def __init__(self, shape: tuple, output_dim: int):
        """
        Unified tracker for 2D or 3D tensors.
        Args:
            shape (tuple): Tensor shape
            output_dim (bool):
                - 2D: True = track columns (dim=1), False = track rows (dim=0)
                - 3D: True = track columns (dim=2), False = track rows (dim=1)
        Raises:
            ValueError: If ``shape`` is neither 2D nor 3D.
        """
        self.shape = shape
        self.output_dim = output_dim
        rank = len(shape)
        if rank not in (2, 3):
            raise ValueError("Only 2D or 3D tensors supported")
        # Truthy output_dim tracks the last dimension, otherwise the one before it.
        self.track_dim = rank - 1 if output_dim else rank - 2
        tracker_count = 1 if rank == 2 else shape[0]
        tracked_length = shape[self.track_dim]
        self.trackers = [BitMaskTracker(tracked_length) for _ in range(tracker_count)]

    def mark(self, start: int = 0, end: int = None, batch_id: int = None):
        """
        Mark a slice of the tensor as filled.
        Args:
            start (int): Start index along tracked dimension
            end (int): End index along tracked dimension; defaults to the full
                length of the tracked dimension
            batch_id (int, optional): Batch index for 3D tensors
        Raises:
            ValueError: If the tensor is 3D and ``batch_id`` is not given.
        """
        stop = self.shape[self.track_dim] if end is None else end
        if len(self.shape) == 2:
            target = self.trackers[0]
        elif batch_id is None:
            raise ValueError("batch_id must be provided for 3D tensor")
        else:
            target = self.trackers[batch_id]
        target.mark(start, stop)

    def is_fully_copied(self) -> bool:
        """Return True if the tensor is fully filled along tracked dimension(s)."""
        for tracker in self.trackers:
            if not tracker.is_full():
                return False
        return True
def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
    """
    Attach every entry of ``param_attr_map`` to ``param`` as an attribute.
    Args:
        param: Object that receives the attributes.
        param_attr_map: Attribute name -> value; ``None`` means nothing to set.
    """
    if param_attr_map is not None:
        for attr_name in param_attr_map:
            setattr(param, attr_name, param_attr_map[attr_name])
def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
    """
    Return the [start, end) slice of a weight along its first or last dimension.
    Args:
        weight_or_paramter: Indexable object exposing either ``get_shape()`` or
            a ``shape`` attribute.
        output_dim: Truthy -> slice the last dimension, falsy -> slice the first
            dimension. Ignored for 1-D inputs.
        start: Start index (inclusive).
        end: End index (exclusive).
        step: Accepted but not used by this function.
    Returns:
        The result of indexing ``weight_or_paramter`` with the computed slice.
    """
    has_shape_getter = hasattr(weight_or_paramter, "get_shape")
    dims = weight_or_paramter.get_shape() if has_shape_getter else weight_or_paramter.shape
    window = slice(start, end)
    if len(dims) == 1:
        index = window
    elif output_dim:
        index = (Ellipsis, window)
    else:
        index = (window, Ellipsis)
    return weight_or_paramter[index]
def process_weights_after_loading(sublayers_dict: dict):
    """
    Build the post-loading hook, e.g. to handle extracted weights
    (quantization, reshaping, etc.).
    Args:
        sublayers_dict: Sublayer name -> sublayer object.
    Returns:
        Callable ``fn(model_sublayer_name, param=None)`` that post-processes the
        named sublayer, or does nothing when the name is unknown.
    """

    def fn(model_sublayer_name: str, param=None):
        # Function-scope import, kept at the top of the call as in the original.
        from fastdeploy.model_executor.layers.linear import KVBatchLinear

        if model_sublayer_name not in sublayers_dict:
            return
        layer = sublayers_dict[model_sublayer_name]
        if isinstance(layer, KVBatchLinear):
            layer.process_weights_after_loading()
        if not hasattr(layer, "quant_method"):
            return
        quant_method = layer.quant_method
        if not hasattr(quant_method, "process_weights_after_loading"):
            return
        is_tracked = param is not None and hasattr(param, "tensor_track")
        # Skip while a tracked parameter has not been completely copied yet.
        if is_tracked and not param.tensor_track.is_fully_copied():
            return
        quant_method.process_weights_after_loading(layer)

    return fn
@dataclass
class WeightsMapper:
    """Rename checkpoint weight names by prefix.

    Every entry of ``orig_to_new_prefix`` whose key is a prefix of the current
    name is applied in mapping order; later entries see the already-renamed
    name. An entry whose value is ``None`` drops the weight.
    """

    # Original prefix -> replacement prefix, or None to drop matching weights.
    orig_to_new_prefix: Mapping[str, Optional[str]] = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """
        Apply the prefix mapping to ``key``.
        Returns:
            The renamed key, or ``None`` if a matching prefix maps to ``None``.
        """
        for prefix, new_prefix in self.orig_to_new_prefix.items():
            if not key.startswith(prefix):
                continue
            if new_prefix is None:
                # The mapping type allows None values; previously this case
                # raised TypeError inside str.replace.
                return None
            key = new_prefix + key[len(prefix) :]
        return key

    def apply(self, weight_name):
        """Return the mapped name for ``weight_name`` (``None`` if dropped)."""
        return self._map_name(weight_name)
def process_weights_before_loading(
    *, skip_prefixes: Optional[List[str]] = None, mapper: Optional["WeightsMapper"] = None
):
    """
    Build the function that preprocesses a checkpoint weight name.
    Args:
        skip_prefixes: Weights whose (mapped) name starts with any of these
            prefixes are skipped.
        mapper: Optional object whose ``apply(name)`` renames the weight; a
            ``None`` result means the weight is dropped.
    Returns:
        Callable ``fn(weight_name)`` returning the mapped name, or ``None`` when
        the weight must not be loaded.
    """

    def _can_skip(weight_name):
        return any(weight_name.startswith(p) for p in (skip_prefixes or ()))

    def fn(weight_name):
        if mapper is not None:
            weight_name = mapper.apply(weight_name)
            if weight_name is None:
                # Dropped by the mapper; calling startswith on None would raise.
                return None
        if _can_skip(weight_name):
            return None
        return weight_name

    return fn
def free_tensor(tensor):
    """
    Drop the tensor's tracker (if it has one) and clear its underlying storage.
    Args:
        tensor: Object exposing ``value().get_tensor()._clear()`` and,
            optionally, a ``tensor_track`` attribute.
    """
    is_tracked = hasattr(tensor, "tensor_track")
    if is_tracked:
        setattr(tensor, "tensor_track", None)
    underlying = tensor.value().get_tensor()
    underlying._clear()
def default_weight_loader(fd_config: Optional[FDConfig] = None) -> Callable[..., None]:
    """
    Build the default weight loader.
    Args:
        fd_config: FastDeploy configuration. When it is ``None`` the loader
            never applies tensor-parallel slicing.
    Returns:
        Callable ``fn(param, loaded_weight, shard_id=None)`` that copies
        ``loaded_weight`` into ``param`` in place.
    """

    def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
        """
        Copy ``loaded_weight`` into ``param``.
        Args:
            param: Destination parameter. Its optional ``output_dim`` and
                ``weight_need_transpose`` attributes steer slicing/transposition.
            loaded_weight: Checkpoint weight; either a ``paddle.Tensor`` or an
                object exposing ``get_shape()``.
            shard_id: Accepted but not used by this loader.
        """
        output_dim = getattr(param, "output_dim", None)
        weight_need_transpose = getattr(param, "weight_need_transpose", False)
        if weight_need_transpose:
            loaded_weight = get_tensor(loaded_weight)
            loaded_weight = loaded_weight.transpose([1, 0])
        # Tensor parallelism splits the weight along the output_dim
        if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1:
            tp_size = fd_config.parallel_config.tensor_parallel_size
            tp_rank = fd_config.parallel_config.tensor_parallel_rank
            dim = -1 if output_dim else 0
            if isinstance(loaded_weight, paddle.Tensor):
                size = loaded_weight.shape[dim]
            else:
                size = loaded_weight.get_shape()[dim]
            block_size = size // tp_size
            # Half-open range [shard_start, shard_end) owned by this rank.
            shard_start = tp_rank * block_size
            shard_end = shard_start + block_size
            loaded_weight = slice_fn(loaded_weight, output_dim, shard_start, shard_end)
        loaded_weight = get_tensor(loaded_weight)
        # Align the dtype with the parameter (e.g. mlp.gate.weight is
        # precision-sensitive and is kept in float32 for computation).
        if param.dtype != loaded_weight.dtype:
            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
                # int8 -> float8_e4m3fn goes through view() rather than cast().
                loaded_weight = loaded_weight.view(param.dtype)
            else:
                loaded_weight = loaded_weight.cast(param.dtype)
        if param.shape != loaded_weight.shape:
            # for e_score_correction_bias
            loaded_weight = loaded_weight.reshape(param.shape)
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)

    return fn
def is_pre_sliced_weight(model_path):
    """
    Tell whether ``model_path`` contains more than one ``rank*`` sub-directory.
    Args:
        model_path: Directory to inspect.
    Returns:
        bool: True if at least two entries whose name starts with "rank" are
        directories.
    """
    rank_dir_count = 0
    for entry in os.listdir(model_path):
        if not entry.startswith("rank"):
            continue
        if os.path.isdir(os.path.join(model_path, entry)):
            rank_dir_count += 1
    return rank_dir_count > 1
def is_paddle_support_v1_loader() -> bool:
    """
    Probe whether the installed Paddle supports the in-place slice copy used here.

    A [32, 32] tensor of ones is copied with ``copy_`` into the first and the
    second half (along the last dimension) of each ``tgt_tensor[exp_id]`` of a
    zero-initialised [1, 32, 64] tensor.
    Returns:
        bool: True if every element of the target tensor equals 1 afterwards,
        i.e. the copies into the indexed slices reached the target tensor.
    """
    src_shape = [32, 32]
    tgt_shape = [1, 32, 64]
    src_tensor = paddle.ones(src_shape, dtype="float32")
    tgt_tensor = paddle.zeros(tgt_shape, dtype="float32")
    for exp_id in range(tgt_shape[0]):
        # gate: first half of the last dimension
        gate_tgt = tgt_tensor[exp_id][..., : tgt_shape[2] // 2]
        gate_tgt.copy_(src_tensor, False)
        # up: second half of the last dimension
        up_tgt = tgt_tensor[exp_id][..., tgt_shape[2] // 2 :]
        up_tgt.copy_(src_tensor, False)
    # The check is made on the original target tensor, not on the slices.
    is_same = bool(paddle.all(tgt_tensor == 1))
    return is_same
def v1_loader_support(fd_config):
    """
    Decide whether the v1 loader can be used for the given configuration.

    The checks run in order; the first one that fails logs the reason at info
    level (with a note about falling back to the v0 loader) and returns False.
    Args:
        fd_config: FastDeploy configuration; ``model_config``,
            ``parallel_config`` and ``quant_config`` are read.
    Returns:
        bool: True only if every check passes.
    """
    _v1_no_support_archs = [
        "Qwen2VLForConditionalGeneration",
    ]

    def _log_fallback(msg: str) -> None:
        logger.info(msg + "; fallback to the v0 loader for model loading.")

    if not current_platform.is_cuda():
        _log_fallback("v1 loader currently does not support backends other than CUDA")
        return False

    if is_pre_sliced_weight(fd_config.model_config.model):
        _log_fallback("v1 loader currently does not support pre-sliced weights")
        return False

    if fd_config.parallel_config.use_ep:
        _log_fallback("v1 loader currently does not support expert parallelism")
        return False

    if envs.FD_MOE_BACKEND.lower() == "marlin":
        _log_fallback("v1 loader currently does not support marlin backend")
        return False

    if fd_config.quant_config is not None:
        if fd_config.quant_config.name() == "mix_quant":
            moe_quant_type = fd_config.quant_config.moe_quant_type
            dense_quant_type = fd_config.quant_config.dense_quant_type
        else:
            moe_quant_type = fd_config.quant_config.name()
            dense_quant_type = fd_config.quant_config.name()
        unsupported_quant = {"w4a8", "w4afp8", "wint2"}

        if unsupported_quant & {moe_quant_type, dense_quant_type}:
            # The message now spells "wint2" to match the set above.
            _log_fallback("v1 loader currently does not support w4a8/w4afp8/wint2 quantization")
            return False

    if fd_config.model_config.architectures[0] in _v1_no_support_archs:
        _log_fallback(f"v1 loader currently does not support {fd_config.model_config.architectures[0]}")
        return False

    # Checked last: this probe allocates tensors and runs Paddle ops.
    if not is_paddle_support_v1_loader():
        _log_fallback("The installed Paddle does not support v1 loader")
        return False
    return True
@contextmanager
def temporary_dtype(dtype: str):
    """
    Temporarily set Paddle default dtype.

    The default dtype is only switched when ``dtype`` equals "float32"; any
    other value (including ``None``) leaves it untouched. The dtype that was
    active on entry is always restored on exit.
    """
    saved_dtype = paddle.get_default_dtype()
    try:
        if dtype == "float32":
            paddle.set_default_dtype(dtype)
        yield
    finally:
        paddle.set_default_dtype(saved_dtype)
@contextmanager
def switch_config_context(config_obj, config_attr_name, value):
    """
    Temporarily override one attribute of a config object.
    Args:
        config_obj: Object whose attribute is overridden.
        config_attr_name: Name of the attribute; it must already exist.
        value: Value the attribute holds inside the ``with`` body.
    The previous value is put back on exit, even if the body raises.
    """
    previous_value = getattr(config_obj, config_attr_name)
    setattr(config_obj, config_attr_name, value)
    try:
        yield
    finally:
        setattr(config_obj, config_attr_name, previous_value)
def rename_offline_ckpt_suffix_to_fd_suffix(
    fd_config, ckpt_weight_suffix: str = "quant_weight", ckpt_scale_suffix="weight_scale"
):
    """
    Create a function to rename checkpoint key suffixes for FastDeploy.

    For ``block_wise_fp8`` quantization the checkpoint weight suffix
    (default "quant_weight") becomes "weight" and the checkpoint scale suffix
    (default "weight_scale") becomes "weight_scale_inv". Only the suffix is
    changed; other occurrences of the same text in the key are left alone.
    Args:
        fd_config: FastDeploy configuration; ``quant_config`` is read.
        ckpt_weight_suffix: Suffix of quantized weight keys in the checkpoint.
        ckpt_scale_suffix: Suffix of weight-scale keys in the checkpoint.
    Returns:
        Callable: ``fn(loaded_weight_name, is_moe)`` returning the renamed key.
        The key is returned unchanged when there is no quant config, when the
        checkpoint is bf16, or when the applicable quant type has no mapping.
    """
    fp8_suffix_map = {
        ckpt_weight_suffix: "weight",
        ckpt_scale_suffix: "weight_scale_inv",
    }
    moe_quant_type = ""
    dense_quant_type = ""
    if fd_config.quant_config is not None:
        if fd_config.quant_config.name() == "mix_quant":
            moe_quant_type = fd_config.quant_config.moe_quant_type
            dense_quant_type = fd_config.quant_config.dense_quant_type
        else:
            moe_quant_type = fd_config.quant_config.name()
            dense_quant_type = fd_config.quant_config.name()

    def fn(loaded_weight_name, is_moe):
        if fd_config.quant_config is None or fd_config.quant_config.is_checkpoint_bf16:
            return loaded_weight_name
        quant_type = moe_quant_type if is_moe else dense_quant_type
        # Can be extended to other offline quantization suffixes if needed.
        # The map is always bound here; the original left it unbound for
        # non-fp8 quant types and raised UnboundLocalError.
        fd_suffix_map = fp8_suffix_map if quant_type == "block_wise_fp8" else {}
        for ckpt_suffix, fd_suffix in fd_suffix_map.items():
            # Literal suffix match, rewriting the trailing occurrence only.
            if ckpt_suffix and loaded_weight_name.endswith(ckpt_suffix):
                return loaded_weight_name[: -len(ckpt_suffix)] + fd_suffix
        return loaded_weight_name

    return fn