"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Dict, Optional

import numpy as np
import paddle
from paddle import nn
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs

from .utils import get_tensor


class ParallelLMHead(nn.Layer):
    """
    Parallelized LM head.

    Projects hidden states to vocabulary logits. Depending on
    ``fd_config.parallel_config.use_ep`` the projection is either a plain
    parameter owned by this layer (expert-parallel mode) or a fleet
    tensor-parallel linear layer stored in ``self.linear``.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        num_embeddings: int,
        embedding_dim: int,
        prefix: str = "",
        with_bias: bool = False,
    ) -> None:
        """
        Build the LM head.

        Args:
            fd_config (FDConfig): Inference configuration. This layer reads
                ``parallel_config.use_ep``,
                ``parallel_config.tensor_parallel_size`` and
                ``model_config.tie_word_embeddings`` from it.
            num_embeddings (int): Vocabulary size (output dimension).
            embedding_dim (int): Size of the hidden state (input dimension).
            prefix (str): Name of the current layer; used to build the
                checkpoint keys ``prefix + ".weight"`` and
                ``prefix + ".bias"``. Defaults to "".
            with_bias (bool): Whether the head has a bias. Defaults to False.
        """
        super().__init__()
        self.weight_key: str = prefix + ".weight"
        # ``bias_key`` doubles as the "has bias" flag: None means no bias.
        self.bias_key: Optional[str] = prefix + ".bias" if with_bias else None
        self.use_ep: bool = fd_config.parallel_config.use_ep
        # Always True here, so the RowParallelLinear branch below is
        # currently never taken; it is kept for parity with the original.
        self.column_cut = True
        self.nranks = fd_config.parallel_config.tensor_parallel_size

        ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
        RowParallelLinear = fleet.meta_parallel.RowParallelLinear
        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings

        has_bias = self.bias_key is not None

        if self.use_ep:
            # Expert-parallel mode: the layer owns an unsharded
            # [embedding_dim, num_embeddings] weight directly.
            self.weight = self.create_parameter(
                shape=[embedding_dim, num_embeddings],
                dtype=paddle.get_default_dtype(),
                is_bias=False,
            )
            if has_bias:
                self.bias = self.create_parameter(
                    shape=[num_embeddings],
                    dtype=paddle.get_default_dtype(),
                    is_bias=True,
                )
        else:
            mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
            if self.column_cut:
                self.linear = ColumnParallelLinear(
                    embedding_dim,
                    num_embeddings,
                    mp_group=mp_group,
                    weight_attr=None,
                    has_bias=has_bias,
                    gather_output=True,
                    fuse_matmul_bias=False,  # False yields a smaller numerical diff
                )
                if self.nranks > 1:
                    set_weight_attrs(self.linear.weight, {"output_dim": True})
            else:
                self.linear = RowParallelLinear(
                    embedding_dim,
                    num_embeddings,
                    mp_group=mp_group,
                    weight_attr=None,
                    has_bias=has_bias,
                    input_is_parallel=False,
                    fuse_matmul_bias=False,  # False yields a smaller numerical diff
                )
                if self.nranks > 1:
                    set_weight_attrs(self.linear.weight, {"output_dim": False})

    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """
        Load the checkpoint state dictionary into the layer.

        The entries for ``self.weight_key`` (and ``self.bias_key`` when a bias
        is configured) are popped from ``state_dict``, cast to paddle's
        default dtype and copied into the layer's parameters.

        Args:
            state_dict (dict): A dictionary containing the checkpoint weights
                and biases. It is mutated: consumed keys are removed.
        """
        dtype = paddle.get_default_dtype()

        if self.use_ep:
            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(dtype))
            if self.bias_key is not None:
                # Fixed: the original read the never-defined
                # ``self.linear_bias_key`` and raised AttributeError.
                self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(dtype))
            return

        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(dtype)
        if self.tie_word_embeddings:
            # Tied weights are always transposed before being assigned.
            weight_tensor = weight_tensor.transpose([1, 0])
        elif self.linear.weight.shape != weight_tensor.shape:
            # Untied weights are transposed only when the checkpoint layout
            # does not already match the linear layer's weight shape.
            weight_tensor = weight_tensor.transpose([1, 0])
        self.linear.weight.set_value(weight_tensor)

        if self.bias_key is not None:
            bias = get_tensor(state_dict.pop(self.bias_key)).astype(dtype)
            self.linear.bias.set_value(bias)

    def forward(self, input: paddle.Tensor) -> paddle.Tensor:
        """
        Defines the forward computation of the layer.

        Args:
            input (Tensor): The input tensor to the layer.

        Returns:
            Tensor: The output tensor after processing through the layer.
        """
        if not self.use_ep:
            return self.linear(input)

        # Fixed: the original tested the never-defined
        # ``self.linear_bias_key`` and raised AttributeError in EP mode.
        if self.bias_key is None:
            return paddle.matmul(input, self.weight)
        return paddle.incubate.nn.functional.fused_linear(input, self.weight, self.bias)