diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index 47ce9365f..d6958a919 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -348,8 +348,7 @@ class ColumnParallelLinear(LinearBase):
         if self.with_bias:
             # col parallel
             _set_var_distributed(self.bias, split_axis=1)
-            if self.nranks > 1:
-                set_weight_attrs(self.bias, {"output_dim": True})
+            set_weight_attrs(self.bias, {"output_dim": True})
 
 
 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -404,6 +403,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         output_dim = getattr(param, "output_dim", None)
+        assert output_dim is not None
         shard_dim = -1 if output_dim else 0
         output_size = param.shape[shard_dim]
         if loaded_shard_id is None:
@@ -517,11 +517,12 @@ class QKVParallelLinear(ColumnParallelLinear):
             with_bias=with_bias,
             add_bias=add_bias,
         )
-        setattr(self.weight, "output_dim", True)
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         output_dim = getattr(param, "output_dim", None)
-        head_dim = param.shape[output_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
+        assert output_dim is not None
+        dim = -1 if output_dim else 0
+        head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk
             shard_offsets = [
@@ -540,7 +541,6 @@ class QKVParallelLinear(ColumnParallelLinear):
             assert loaded_shard_id in ["q", "k", "v"]
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
-                dim = -1 if output_dim else 0
                 if isinstance(loaded_weight, np.ndarray):
                     size = loaded_weight.shape[dim]
                 else:
@@ -717,13 +717,12 @@ class RowParallelLinear(LinearBase):
         if self.with_bias:
             # col parallel
             _set_var_distributed(self.bias, split_axis=0)
-            if self.nranks > 1:
-                set_weight_attrs(
-                    self.bias,
-                    {
-                        "output_dim": False,
-                    },
-                )
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": False,
+                },
+            )
 
         self.reduce_results = reduce_results
 
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index 69f7dd81a..3682b5dc1 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import re
 from functools import partial
 
 import paddle
@@ -314,7 +315,10 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         Args:
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
-        from fastdeploy.model_executor.models.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )
 
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -328,6 +332,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
         ]
 
         params_dict = dict(self.named_parameters())
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
@@ -340,11 +345,14 @@ class Qwen2ForCausalLM(ModelForCasualLM):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if loaded_weight_name not in params_dict:
+                model_param_name = loaded_weight_name
+                if model_param_name not in params_dict:
                     continue
-                param = params_dict[loaded_weight_name]
+                param = params_dict[model_param_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight)
+                model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
+                process_weights_after_loading_fn(model_sublayer_name, param)
 
     @classmethod
     def name(self):
diff --git a/tests/model_loader/test_common_model.py b/tests/model_loader/test_common_model.py
index e4eec9925..b8b005f02 100644
--- a/tests/model_loader/test_common_model.py
+++ b/tests/model_loader/test_common_model.py
@@ -99,6 +99,9 @@ model_param_map = {
         "tensor_parallel_size": 2,
         "quantizations": ["wint8"],
     },
+    "Qwen2-7B-Instruct": {
+        "quantizations": ["None", "wint8"],
+    },
 }
 
 params = []