diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index 5f4291faf..b3fbaa754 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -959,7 +959,7 @@ class KVBatchLinear(nn.Layer):
         # Split num_attention_heads when using TP inference.
         self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
-
+        self.fd_config = fd_config
         self.kv_b_proj = kv_b_proj
         self.weight_dtype = self._helper.get_default_dtype()
@@ -968,7 +968,8 @@ class KVBatchLinear(nn.Layer):
         self.weight_key = f"{prefix}.weight"  # e.g., "kv_b_proj.weight"

     def process_weights_after_loading(self):
-
+        if self.fd_config.load_config.dynamic_load_weight:
+            return
         w = self.kv_b_proj.weight.reshape(
             [
                 self.kv_lora_rank,
diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py
index bfc544ffb..d8857f984 100644
--- a/fastdeploy/model_executor/layers/lm_head.py
+++ b/fastdeploy/model_executor/layers/lm_head.py
@@ -26,12 +26,7 @@ from fastdeploy.model_executor.layers.utils import (
     DEFAULT_VOCAB_PADDING_SIZE,
     pad_vocab_size,
 )
-from fastdeploy.model_executor.utils import (
-    default_weight_loader,
-    free_tensor,
-    set_weight_attrs,
-    temporary_dtype,
-)
+from fastdeploy.model_executor.utils import set_weight_attrs, temporary_dtype

 from .utils import get_tensor

@@ -80,7 +75,6 @@ class ParallelLMHead(nn.Layer):
         if num_embeddings % self.nranks != 0:
             num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
         self.num_embeddings = num_embeddings
-        self.model_format = fd_config.model_config.model_format

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -90,39 +84,21 @@ class ParallelLMHead(nn.Layer):
             self.need_gather = True

         with temporary_dtype(self.dtype):
-            if self.fd_config.load_config.load_choices == "default_v1" and (
-                self.model_format == "torch" or self.tie_word_embeddings
-            ):
-                self.linear = RowParallelLinear(
-                    num_embeddings,
-                    embedding_dim,
-                    mp_group=self.tp_group,
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,
-                )
-                set_weight_attrs(
-                    self.linear.weight,
-                    {
-                        "weight_loader": default_weight_loader(self.fd_config),
-                    },
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
-            elif self.column_cut:
+            if self.column_cut:
+                need_gather = True
                 self.linear = ColumnParallelLinear(
                     embedding_dim,
                     num_embeddings,
                     mp_group=self.tp_group,
                     weight_attr=None,
                     has_bias=True if self.bias_key is not None else False,
-                    gather_output=self.need_gather,
+                    gather_output=need_gather,
                     fuse_matmul_bias=False,
                 )
                 set_weight_attrs(
                     self.linear.weight,
                     {
-                        "weight_loader": default_weight_loader(self.fd_config),
+                        "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                     },
                 )
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
@@ -139,34 +115,11 @@ class ParallelLMHead(nn.Layer):
                 set_weight_attrs(
                     self.linear.weight,
                     {
-                        "weight_loader": default_weight_loader(self.fd_config),
+                        "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                     },
                 )
                 set_weight_attrs(self.linear.weight, {"output_dim": False})

-    def process_weights_after_loading(self):
-        if not (
-            self.fd_config.load_config.load_choices == "default_v1"
-            and (self.model_format == "torch" or self.tie_word_embeddings)
-        ):
-            return
-        if not self.linear.weight._is_initialized():
-            self.linear.weight.initialize()
-        weight_transpose = self.linear.weight.transpose([1, 0])
-        with temporary_dtype(self.dtype):
-            linear = fleet.meta_parallel.ColumnParallelLinear(
-                self.embedding_dim,
-                self.num_embeddings,
-                mp_group=self.tp_group,
-                weight_attr=None,
-                has_bias=True if self.bias_key is not None else False,
-                gather_output=self.need_gather,
-                fuse_matmul_bias=False,
-            )
-            linear.weight.set_value(weight_transpose)
-            free_tensor(self.linear.weight)
-            self.linear = linear
-
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
         Load the checkpoint state dictionary into the layer.
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 1975bb375..b3155e100 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -1422,13 +1422,6 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         down_proj_weight_name = self.added_weight_attrs[1]
         up_gate_proj_scale_name = self.added_scale_attrs[0]
         down_proj_scale_name = self.added_scale_attrs[1]
-        if (
-            not weight_fully_copied(getattr(layer, up_gate_proj_weight_name))
-            or not weight_fully_copied(getattr(layer, down_proj_weight_name))
-            or not weight_fully_copied(getattr(layer, up_gate_proj_scale_name))
-            or not weight_fully_copied(getattr(layer, down_proj_scale_name))
-        ):
-            return
         process_weight_transpose(layer, up_gate_proj_weight_name)
         process_weight_transpose(layer, down_proj_weight_name)
         process_weight_transpose(layer, up_gate_proj_scale_name)
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index ea3497478..2920f2d51 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -63,7 +63,7 @@ def get_moe_method():
         )
         return MetaxCutlassUnquantizedFusedMoEMethod(None)

-    raise NotImplementedError
+    return None


 def get_moe_scores(
@@ -189,7 +189,9 @@ class FusedMoE(nn.Layer):
             self.quant_method = moe_quant_config.get_quant_method(self)
             self.moe_quant_type = moe_quant_config.name()
         else:
+            # unquantized quant_method
             self.quant_method = get_moe_method()
+            assert self.quant_method is not None, "self.quant_method should not be None"
         self.redundant_table_manger = redundant_table_manger
         if self.ep_size > 1:
             self.quant_method.init_ep(self)
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 533e0061f..408607b10 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -62,10 +62,13 @@ def load_weights_from_cache(model, weights_iterator):
             logger.info(f"{loaded_weight_name} is not in model parameters.")
             continue
         param = params_dict[loaded_weight_name]
+        if param.shape != loaded_weight.shape:
+            raise ValueError(
+                f"Shape mismatch between loaded weight {loaded_weight_name}: {loaded_weight.shape}, expected shape: {param.shape}"
+            )
         param.copy_(loaded_weight, False)
         if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
-            model.lm_head.linear.weight.set_value(loaded_weight)
-            model.lm_head.process_weights_after_loading()
+            model.lm_head.linear.weight.set_value(loaded_weight.transpose([1, 0]))
     for _, model_sublayer in model.named_sublayers():
         if isinstance(model_sublayer, KVBatchLinear):
             model_sublayer.process_weights_after_loading()
@@ -107,7 +110,6 @@ def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
     weight_cache_context = multi_switch_config_context(
         (fd_config.quant_config, "is_checkpoint_bf16", False),
-        (fd_config.model_config, "model_format", "paddle"),
     )

     return enable_cache, weight_cache_dir, weight_cache_context
diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py
index d688e1dde..8fb0ebf38 100644
--- a/fastdeploy/model_executor/model_loader/default_loader_v1.py
+++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -56,8 +56,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
             load_weights_from_cache(model, weights_iterator)
         else:
             model.load_weights(weights_iterator)
-        if fd_config.speculative_config.model_type != "mtp":
-            process_final_after_loading(model, fd_config)
+            if fd_config.speculative_config.model_type != "mtp":
+                process_final_after_loading(model, fd_config)
         self.clean_memory_fragments()

@@ -76,6 +76,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
             architectures = architectures + "RL"

         enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
+        fd_config.model_config.enable_cache = enable_cache
         with weight_cache_context:
             with context:
                 model_cls = ModelRegistry.get_class(architectures)
@@ -88,6 +89,8 @@ class DefaultModelLoaderV1(BaseModelLoader):
                 assert_never(convert_type)

         model = model_cls(fd_config)
+        if fd_config.load_config.dynamic_load_weight or fd_config.model_config.enable_cache:
+            process_final_after_loading(model, fd_config)
         model.eval()

         # RL model not need set_state_dict
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 7f3855160..75947590b 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -600,7 +600,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)

         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
+            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index be7744131..a291db0e9 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -718,7 +718,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
             )
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
+            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))

     @paddle.no_grad()
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index 69b010c0b..d49f3a327 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -377,7 +377,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
             model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight)
+            self.lm_head.linear.weight.set_value(self.qwen2.embed_tokens.embeddings.weight.transpose([1, 0]))

     @classmethod
     def name(self):
diff --git a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
index 53f5c766e..531f530c4 100644
--- a/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
+++ b/fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
@@ -231,7 +231,7 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)

         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight)
+            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))

     @paddle.no_grad()
     def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index 8b1004d76..4b6e27808 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -320,7 +320,7 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             process_weights_after_loading_fn(model_sublayer_name, param)

         if self.tie_word_embeddings and not is_pooling_model:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight)
+            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))

     @paddle.no_grad()
     def set_state_dict(self, state_dict):
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index e04341061..3b42e0294 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -131,16 +131,24 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
 def process_weight_transpose(layer, weight_name):
     weight = getattr(layer, weight_name)
     if len(weight.shape) == 2:
-        weight_transpose = weight.transpose([1, 0])
+        weight_shape = weight.shape[::-1]
     elif len(weight.shape) == 3:
-        weight_transpose = weight.transpose([0, 2, 1])
-
+        weight_shape = [weight.shape[0]] + list(weight.shape[1:][::-1])
     weight_tmp = layer.create_parameter(
-        shape=weight_transpose.shape,
-        dtype=weight_transpose.dtype,
+        shape=weight_shape,
+        dtype=weight.dtype,
         default_initializer=paddle.nn.initializer.Constant(0),
         is_bias=False,
     )
+    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
+        free_tensor(weight)
+        setattr(layer, weight_name, weight_tmp)
+        return
+
+    if len(weight.shape) == 2:
+        weight_transpose = weight.transpose([1, 0])
+    elif len(weight.shape) == 3:
+        weight_transpose = weight.transpose([0, 2, 1])
     weight_tmp.copy_(weight_transpose, False)
     free_tensor(weight)
     setattr(layer, weight_name, weight_tmp)
@@ -163,9 +171,16 @@ def process_weights_after_loading(sublayers_dict: dict, fd_config: FDConfig):
         model_sublayer = sublayers_dict[model_sublayer_name]
         if isinstance(model_sublayer, KVBatchLinear):
             model_sublayer.process_weights_after_loading()
+        if fd_config.quant_config and not fd_config.quant_config.is_checkpoint_bf16:
+            # skip for offline quantization
+            return
         if hasattr(model_sublayer, "quant_method"):
             quant_method = getattr(model_sublayer, "quant_method", None)
-            unquant_moe_cls = type(get_moe_method())
+            unquant_moe_layer = get_moe_method()
+            if unquant_moe_layer is None:
+                unquant_moe_cls = object
+            else:
+                unquant_moe_cls = type(unquant_moe_layer)
             if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls:
                 # skip unquantized linear
                 return
@@ -225,18 +240,23 @@ def process_final_after_loading(model, fd_config: FDConfig):
     from fastdeploy.model_executor.layers.moe.moe import get_moe_method

     for name, sublayer in model.named_sublayers():
-        quant_method = getattr(sublayer, "quant_method", None)
-        if quant_method is not None:
-            unquant_moe_cls = type(get_moe_method())
-            if not (type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls):
-                continue
-            if hasattr(quant_method, "process_weights_after_loading"):
-                quant_method.process_weights_after_loading(sublayer)
         if isinstance(sublayer, KVBatchLinear):
             continue
+        quant_method = getattr(sublayer, "quant_method", None)
+        if quant_method is not None:
+            unquant_moe_layer = get_moe_method()
+            if unquant_moe_layer is None:
+                unquant_moe_cls = object
+            else:
+                unquant_moe_cls = type(unquant_moe_layer)
+            is_unquant_cls = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
+            is_offline_quantized_ckpt = not (fd_config.quant_config and fd_config.quant_config.is_checkpoint_bf16)
+            if is_unquant_cls or is_offline_quantized_ckpt:
+                if hasattr(quant_method, "process_weights_after_loading"):
+                    quant_method.process_weights_after_loading(sublayer)
+            continue
         if not hasattr(sublayer, "process_weights_after_loading"):
             continue
-        # Only for specific layers, such as lmhead
         sublayer.process_weights_after_loading()
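A minimal sketch (not part of the patch) of the shape convention the `tie_word_embeddings` changes above depend on: the embedding table is laid out as `[vocab_size, hidden_size]`, while the lm_head `ColumnParallelLinear` weight sits in the `[in_features, out_features]` layout implied by the `.transpose([1, 0])` calls in this diff. The sizes below are made up and tensor-parallel sharding is ignored.

```python
# Illustrative only; vocab_size/hidden_size are made-up numbers.
import paddle

vocab_size, hidden_size = 32000, 4096

# Embedding table: one row per token id -> [vocab_size, hidden_size].
embed_weight = paddle.zeros([vocab_size, hidden_size], dtype="float32")

# Tied lm_head weight is the transposed embedding table -> [hidden_size, vocab_size].
lm_head_weight = embed_weight.transpose([1, 0])
assert lm_head_weight.shape == [hidden_size, vocab_size]

# Logits: [batch, hidden] x [hidden, vocab] -> [batch, vocab].
hidden_states = paddle.zeros([2, hidden_size], dtype="float32")
logits = paddle.matmul(hidden_states, lm_head_weight)
assert logits.shape == [2, vocab_size]
```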