[V1 Loader] Support qwen2(bf16) (#3502)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* support qwen2(bf16)

* merge bias_loader and weight_loader
This commit is contained in:
Zero Rains
2025-08-23 01:08:23 +08:00
committed by GitHub
parent cb166053ba
commit 79f0dbbb55
3 changed files with 76 additions and 13 deletions

View File

@@ -26,6 +26,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
from fastdeploy.model_executor.models.utils import (
default_weight_loader,
set_weight_attrs,
slice_fn,
)
from fastdeploy.platforms import current_platform
@@ -159,6 +160,11 @@ class LinearBase(nn.Layer):
dtype=self._dtype,
is_bias=True,
)
setattr(
self.bias,
"weight_loader",
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config),
)
# smooth quant
self.linear_shift = None
@@ -503,8 +509,10 @@ class QKVParallelLinear(ColumnParallelLinear):
with_bias=with_bias,
add_bias=add_bias,
)
setattr(self.weight, "output_dim", True)
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
# Loaded weight is already fused on disk
if self.nranks != 1:
@@ -515,7 +523,9 @@ class QKVParallelLinear(ColumnParallelLinear):
("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
loaded_weight_shard = loaded_weight_shard = slice_fn(
loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
loaded_weight = get_tensor(loaded_weight)
@@ -525,10 +535,9 @@ class QKVParallelLinear(ColumnParallelLinear):
# 1.fused qkv in disk
# 2.split q k v
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
dim = -1
dim = -1 if output_dim else 0
if isinstance(loaded_weight, np.ndarray):
size = loaded_weight.shape[dim]
else:
@@ -541,17 +550,16 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = get_tensor(loaded_weight)
if loaded_shard_id == "q":
param = param[:, : self.num_heads_per_rank * self.head_dim]
param_shard_offset = 0
param_shard_size = self.num_heads_per_rank * self.head_dim
elif loaded_shard_id == "k":
param = param[
:,
self.num_heads_per_rank
* self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
* self.head_dim,
]
elif loaded_shard_id == "v":
param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
param_shard_offset = self.num_heads_per_rank * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.head_dim
else:
# loaded_shard_id == "v"
param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim
param_shard_size = self.kv_num_heads_per_rank * self.head_dim
param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)

View File

@@ -305,6 +305,47 @@ class Qwen2ForCausalLM(ModelForCasualLM):
prefix="lm_head",
)
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
params_dict = dict(self.named_parameters())
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
if loaded_weight_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
@classmethod
def name(self):
""" """

View File

@@ -54,6 +54,20 @@ def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]):
setattr(param, key, value)
def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
if hasattr(weight_or_paramter, "get_shape"):
shape = weight_or_paramter.get_shape()
else:
shape = weight_or_paramter.shape
if len(shape) == 1:
weight_or_paramter = weight_or_paramter[start:end]
elif output_dim:
weight_or_paramter = weight_or_paramter[..., start:end]
else:
weight_or_paramter = weight_or_paramter[start:end, ...]
return weight_or_paramter
def default_weight_loader(fd_config: FDConfig) -> None:
"""Default weight loader"""