diff --git a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
index f1ea6572f..918450c74 100644
--- a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
+++ b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
@@ -17,7 +17,6 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.utils import ceil_div
 
@@ -241,7 +240,4 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         intermediate_cache3.reshape_([token_num, top_k, hidden_size])
         out = intermediate_cache3.sum(axis=1)
-
-        if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
         return out
diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
index c13a68f31..e67dd6dbd 100644
--- a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
+++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
@@ -175,13 +175,6 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
         fused_moe_out = intermediate_cache3.sum(axis=1)
         fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
 
-        if layer.tp_size > 1:
-            from fastdeploy.distributed.communication import (
-                tensor_model_parallel_all_reduce,
-            )
-
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
-
         return fused_moe_out
 
     def apply(
diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py
index 405beb1de..db4659da9 100644
--- a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py
+++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py
@@ -211,7 +211,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp"
         self.rank: int = llm_config.parallel_config.tensor_parallel_rank
-        self.nranks = llm_config.parallel_config.tensor_parallel_size
+        self.tp_size = llm_config.parallel_config.tensor_parallel_size
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
@@ -325,7 +325,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
             softmax_mode=0,
         )
 
-        if self.nranks > 1:
+        if self.tp_size > 1:
             from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce_custom,
             )
@@ -368,7 +368,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         )
 
         # all_reduce
-        if self.nranks > 1:
+        if self.tp_size > 1:
             from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce_custom,
             )
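Note: every MoE backend touched in this series previously ended its apply() with the same tensor-parallel tail. A minimal sketch of the repeated pattern being hoisted out of the backends (illustrative names, not the exact code of any single backend):

def moe_apply_tail_before(layer, out, all_reduce):
    # old behaviour: each backend reduced its own rank-local partial sum
    if layer.reduce_results and layer.tp_size > 1:
        out = all_reduce(out)
    return out

# new behaviour: backends return the rank-local partial sum unchanged, and the single
# all-reduce is issued once in FusedMoE.forward (see the moe.py hunk further down).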
diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
index 3d354df99..d803e3d31 100644
--- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py
@@ -20,7 +20,6 @@
 import paddle
 from paddle import nn
 from paddle.nn.quant import weight_quantize
 
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
     MoEMethodBase,
     UnquantizedFusedMoEMethod,
 )
@@ -171,9 +170,6 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             False,
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
-
         return fused_moe_out
 
@@ -301,9 +297,6 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
             False,
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
-
         return fused_moe_out
 
diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py
index 2ef470541..7b61d58b6 100644
--- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py
@@ -18,7 +18,6 @@
 import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
@@ -393,6 +392,4 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
 
-        if layer.reduce_results and layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
         return out
diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
index 2a7d48460..3a14e28e3 100644
--- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
+++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
@@ -17,7 +17,6 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
 from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
 from fastdeploy.model_executor.layers.utils import get_tensor
@@ -255,8 +254,6 @@ class XPUMoEMethod(MoEMethodBase):
             layer.top_k,
             False,  # moe group, used in deepseek
         )
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
 
         return fused_moe_out
 
@@ -314,8 +311,6 @@ class XPUMoEMethod(MoEMethodBase):
             permute_indices_per_token.shape[1],
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
         return tmp_ffn_out
 
     def apply_tp(
diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index 226f4e14c..0fde31096 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -79,7 +79,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
         layer.weight.set_value(weights)
 
     def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
-
         linear_out = paddle.matmul(x, layer.weight)
         if layer.with_bias:
             linear_out = paddle.add(linear_out, layer.bias)
@@ -423,9 +422,9 @@ class ColumnParallelLinear(LinearBase):
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
         self.fd_config = fd_config
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.input_size = input_size
-        self.output_size = divide(output_size, self.nranks)  # Split the output_size using TP inference.
+        self.output_size = divide(output_size, self.tp_size)  # Split the output_size using TP inference.
         self.hidden_size = fd_config.model_config.hidden_size
 
         super().__init__(
@@ -449,7 +448,7 @@ class ColumnParallelLinear(LinearBase):
             model_format=fd_config.model_config.model_format,
         )
 
-        if self.nranks > 0:
+        if self.tp_size > 0:
             if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
@@ -492,7 +491,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         """
         self.activation = activation
         self.hidden_size = fd_config.model_config.hidden_size
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.output_size = output_size
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
 
@@ -522,8 +521,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # Loaded weight is already fused on disk.
             shard_offsets = [
                 # (shard_id, shard_offset, shard_size)
-                ("gate", 0, output_size * self.nranks // 2),
-                ("up", output_size * self.nranks // 2, output_size * self.nranks // 2),
+                ("gate", 0, output_size * self.tp_size // 2),
+                ("up", output_size * self.tp_size // 2, output_size * self.tp_size // 2),
             ]
             for shard_id, shard_offset, shard_size in shard_offsets:
                 loaded_weight_shard = slice_fn(
@@ -537,13 +536,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
-            if self.nranks > 1 and output_dim is not None:
+            if self.tp_size > 1 and output_dim is not None:
                 dim = -1 if output_dim else 0
                 if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                     size = loaded_weight.shape[dim]
                 else:
                     size = loaded_weight.get_shape()[dim]
-                block_size = size // self.nranks
+                block_size = size // self.tp_size
                 shard_offset = self.local_rank * block_size
                 shard_size = (self.local_rank + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
@@ -635,15 +634,15 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
         self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
         self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
-        self.num_heads_per_rank = divide(self.num_heads, self.nranks)
-        if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
+        self.num_heads_per_rank = divide(self.num_heads, self.tp_size)
+        if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
             self.kv_num_heads_per_rank = 1
-            self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
-            output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
+            self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
+            output_size = (self.num_heads + 2 * self.tp_size) * self.head_dim
         else:
-            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
+            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
             self.num_kv_head_replicas = 1
             output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
         input_size = self.hidden_size
@@ -697,7 +696,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
-        if self.nranks > 1 and output_dim is not None:
+        if self.tp_size > 1 and output_dim is not None:
             block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
             shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
             shard_offset = shard_id * block_size
@@ -750,10 +749,10 @@ class QKVParallelLinear(ColumnParallelLinear):
             k_tensor = get_tensor(state_dict.pop(k_weight_key))
             v_tensor = get_tensor(state_dict.pop(v_weight_key))
 
-            if self.kv_num_heads < self.nranks:
+            if self.kv_num_heads < self.tp_size:
                 sharedkv_index = (
                     self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads
-                ) // self.nranks
+                ) // self.tp_size
                 sharedkv_start = sharedkv_index * self.head_dim
                 sharedkv_end = sharedkv_start + self.head_dim
                 k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
@@ -767,10 +766,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             )
             weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
 
-            if self.fd_config.quant_config:
-                self.quant_method.process_loaded_weights(self, weight_tensor)
-            else:
-                self.weight.set_value(weight_tensor)
+            self.quant_method.process_loaded_weights(self, weight_tensor)
 
     def load_state_dict(self, state_dict: dict):
         """
@@ -846,10 +842,8 @@ class RowParallelLinear(LinearBase):
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
         self.fd_config = fd_config
-        self.skip_quant = False
         self.ep_size = fd_config.parallel_config.expert_parallel_size
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_size = fd_config.model_config.hidden_size
         self.head_dim = fd_config.model_config.head_dim
@@ -863,7 +857,7 @@ class RowParallelLinear(LinearBase):
         if self.split_token:
             self.input_size = input_size
         else:
-            self.input_size = divide(input_size, self.nranks)
+            self.input_size = divide(input_size, self.tp_size)
         self.output_size = output_size
 
         super().__init__(
@@ -876,8 +870,7 @@ class RowParallelLinear(LinearBase):
             skip_quant=skip_quant,
             weight_dtype=weight_dtype,
         )
-        if add_bias:
-            assert with_bias, "with_bias must be True when add_bias is True."
+        assert self.quant_method is not None
 
         create_weight_kwargs = dict(
             layer=self,
             ...
             ),
             model_format=fd_config.model_config.model_format,
         )
-        if self.nranks > 0:
+        if self.tp_size > 1:
             create_weight_kwargs["split_axis"] = 0
             create_weight_kwargs["is_distributed"] = True
         self.quant_method.create_weights(**create_weight_kwargs)
 
-        self.reduce_results = reduce_results
+        self.reduce_results = reduce_results and not self.split_token
+
+        if add_bias:
+            assert with_bias, "with_bias must be True when add_bias is True."
+            if self.tp_size > 1 and self.reduce_results:
+                set_weight_attrs(self.bias, {"tp_row_bias": True})
 
     def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
         token_num = x.shape[0]
@@ -912,15 +910,11 @@ class RowParallelLinear(LinearBase):
         if self.split_token:
             x = self.all2all_transpose(x)
 
-        if self.fd_config.quant_config:
-            out = self.quant_method.apply(self, x)
-        else:
-            out = paddle.matmul(x, self.weight)
+        out = self.quant_method.apply(self, x)
 
-        if self.reduce_results and self.nranks > 1 and not self.split_token:
+        if self.reduce_results and self.tp_size > 1:
             out = tensor_model_parallel_all_reduce(out, self.tp_group)
-        if not self.fd_config.quant_config and self.add_bias:
-            out = paddle.add(out, self.bias)
+
         return out
 
@@ -950,16 +944,15 @@ class KVBatchLinear(nn.Layer):
             qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
             v_head_dim (int): Dimension for V projection. Defaults to None.
             with_bias (bool): Whether to include bias or not. Defaults to False.
-            skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.kv_lora_rank = kv_lora_rank
         self.num_attention_heads = num_attention_heads
         self.qk_nope_head_dim = qk_nope_head_dim
         self.v_head_dim = v_head_dim
         # Split num_attention_heads when using TP inference.
-        self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
+        self.num_heads_per_partition = divide(num_attention_heads, self.tp_size)
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
         self.fd_config = fd_config
         self.kv_b_proj = kv_b_proj
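Note on the RowParallelLinear changes above: forward() no longer branches on quant_config. A minimal sketch of the resulting flow, under the assumption (suggested by the UnquantizedLinearMethod.apply shown in the first hunk of this file and by the new assert) that layers without a quantization config still get a quant_method whose apply() performs the matmul and optional bias add:

def row_parallel_forward_sketch(layer, x, all_reduce):
    # every layer now goes through its quant method; for unquantized layers this is
    # expected to be UnquantizedLinearMethod.apply, i.e. matmul plus bias when with_bias
    out = layer.quant_method.apply(layer, x)
    # reduce_results is now "reduce_results and not split_token", so the old
    # "not self.split_token" guard on the all-reduce is folded into the flag itself
    if layer.reduce_results and layer.tp_size > 1:
        out = all_reduce(out)  # sums the rank-local partial matmuls across the TP group
    return out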
diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py
index d8857f984..ff2797a04 100644
--- a/fastdeploy/model_executor/layers/lm_head.py
+++ b/fastdeploy/model_executor/layers/lm_head.py
@@ -68,11 +68,11 @@ class ParallelLMHead(nn.Layer):
         self.embedding_dim = embedding_dim
         self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.fd_config = fd_config
         self.padding_size = padding_size
 
-        if num_embeddings % self.nranks != 0:
+        if num_embeddings % self.tp_size != 0:
             num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
         self.num_embeddings = num_embeddings
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index b83cc339e..d87894b81 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -20,7 +20,6 @@
 from paddle.nn.quant import weight_quantize
 from paddleformers.utils.log import logger
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform
 
 from ..utils import get_tensor
@@ -390,9 +389,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             routed_scaling_factor=1.0,
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
-
         return fused_moe_out
 
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index ba4fdb7cc..4e591f8e0 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -19,7 +19,6 @@
 from paddle import nn
 from paddleformers.utils.log import logger
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
 
@@ -423,7 +422,5 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             False,  # norm_topk_prob
             1.0,
         )[0]
-        if layer.tp_size > 1:
-            tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
 
         return tmp_ffn_out
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
index ca2f4bd25..094d3df8f 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py
@@ -18,7 +18,6 @@
 import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
     tritonmoe_preprocess_func,
 )
@@ -351,7 +350,4 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
         ffn_out.reshape_([token_num, -1, hidden_size])
         ffn_out = ffn_out.sum(axis=1)
 
-        if layer.reduce_results and layer.tp_size > 1:
-            ffn_out = tensor_model_parallel_all_reduce(ffn_out)
-
         return ffn_out
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index e26a051a7..3c1485937 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -18,7 +18,6 @@
 import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import (
     TensorTracker,
@@ -433,8 +432,6 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
 
-        if layer.reduce_results and layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
         return out
 
@@ -838,9 +835,6 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
 
-        if layer.reduce_results and layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
-
         return out
 
@@ -1129,9 +1123,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
 
-        if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
-
         return out
 
@@ -1625,7 +1616,4 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         intermediate_cache3.reshape_([token_num, top_k, hidden_size])
         out = intermediate_cache3.sum(axis=1)
 
-        if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
-
         return out
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
index 43e58a6f1..f75e36bcb 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
@@ -18,7 +18,6 @@
 import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
 from fastdeploy.utils import ceil_div
 
@@ -316,9 +315,6 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
             routed_scaling_factor=1.0,
         )
 
-        if layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
-
         return fused_moe_out
 
@@ -486,7 +482,4 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
         fused_moe_out = paddle.sum(intermediate_cache3, axis=1)
 
-        if layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
-
         return fused_moe_out
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 223c3f84b..e99356e6b 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -21,6 +21,7 @@
 from paddle import nn
 from paddleformers.utils.log import logger
 from fastdeploy import envs
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import h2d_copy, slice_fn
 from fastdeploy.platforms import current_platform
@@ -634,4 +635,7 @@ class FusedMoE(nn.Layer):
             out = self.forward_split_allgather(x, gate)
         else:
             out = self.quant_method.apply(self, x, gate)
+
+        if self.reduce_results and self.tp_size > 1:
+            out = tensor_model_parallel_all_reduce(out, self.tp_group)
         return out
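The moe.py hunk above is where the reduction now lives. A single post-apply all-reduce is sufficient because, under tensor parallelism, each rank holds a column shard of the up projection and the matching row shard of the down projection, so its apply() output is already a partial sum of the full expert output. A self-contained toy check (single process, one expert, no gating; shard sizes and names are illustrative, not the production layout):

import paddle

paddle.seed(0)
x = paddle.randn([4, 8])          # token activations
w_up = paddle.randn([8, 16])      # full up projection
w_down = paddle.randn([16, 8])    # full down projection
ref = paddle.matmul(paddle.nn.functional.silu(paddle.matmul(x, w_up)), w_down)

tp_size, shard = 2, 8
partials = []
for rank in range(tp_size):
    up_r = w_up[:, rank * shard : (rank + 1) * shard]      # column-parallel shard
    down_r = w_down[rank * shard : (rank + 1) * shard, :]  # row-parallel shard
    partials.append(paddle.matmul(paddle.nn.functional.silu(paddle.matmul(x, up_r)), down_r))

out = partials[0] + partials[1]   # what tensor_model_parallel_all_reduce computes across ranks
print(paddle.allclose(ref, out, atol=1e-5).item())  # True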
diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py
index c51523ff1..42493a1d3 100644
--- a/fastdeploy/model_executor/layers/mtp_linear.py
+++ b/fastdeploy/model_executor/layers/mtp_linear.py
@@ -56,7 +56,7 @@ class ParallelEHProjection(nn.Layer):
         self.fd_config = fd_config
         self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -84,7 +84,7 @@ class ParallelEHProjection(nn.Layer):
                 self.linear.bias,
                 {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
             )
-            if self.nranks > 1:
+            if self.tp_size > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
@@ -103,7 +103,7 @@ class ParallelEHProjection(nn.Layer):
                     "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                 },
             )
-            if self.nranks > 1:
+            if self.tp_size > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
                 set_weight_attrs(
                     self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
                 )
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 436b03395..b2ec3fbc9 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -66,7 +66,6 @@ class Ernie4_5_MLP(nn.Layer):
         reduce_results: bool = True,
     ) -> None:
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.up_gate_proj = MergedColumnParallelLinear(
             fd_config=fd_config,
             prefix=f"{prefix}.up_gate_proj",
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index d49f3a327..0a84248b9 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -61,7 +61,6 @@ class Qwen2MLP(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.up_gate_proj = MergedColumnParallelLinear(
             fd_config=fd_config,
             prefix=f"{prefix}.up_gate_proj",
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index 4b6e27808..3fb20da95 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -59,7 +59,6 @@ class Qwen3Attention(nn.Layer):
         self.head_dim = fd_config.model_config.head_dim
 
         self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)
-        nranks = fd_config.parallel_config.tensor_parallel_size
 
         self.o_proj = RowParallelLinear(
             fd_config,
@@ -91,10 +90,10 @@ class Qwen3Attention(nn.Layer):
             begin_norm_axis=2,
         )
 
-        nranks = fd_config.parallel_config.tensor_parallel_size
-        num_kv_heads_replicas = max(1, nranks // fd_config.model_config.num_key_value_heads)
-        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
-        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // nranks
+        tp_size = fd_config.parallel_config.tensor_parallel_size
+        num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
+        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
+        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size
 
     def load_state_dict(self, state_dict):
         """ """
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index 9537b84f2..3e9a72d76 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -97,8 +97,6 @@ class Qwen3MLP(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
-
         self.up_gate_proj = MergedColumnParallelLinear(
             fd_config,
             prefix=f"{prefix}.up_gate_proj",
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 3b42e0294..aec4d550f 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -298,6 +298,10 @@ def default_weight_loader(fd_config: FDConfig = None) -> None:
                 shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
 
+            tp_row_bias = getattr(param, "tp_row_bias", None)
+            if tp_row_bias:
+                loaded_weight = loaded_weight / fd_config.parallel_config.tensor_parallel_size
+
             # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
             loaded_weight = fd_cast(loaded_weight, param)
             if param.shape != loaded_weight.shape:
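The tp_row_bias handling above pairs with the RowParallelLinear change earlier in this diff: quant_method.apply now adds the bias on every rank before the all-reduce, so a reduced row-parallel layer would otherwise count the bias tp_size times. Dividing the loaded bias by tensor_parallel_size restores a single bias contribution. A toy check (single process; shard sizes are illustrative):

import paddle

paddle.seed(0)
tp_size = 2
x = paddle.randn([4, 8])
w = paddle.randn([8, 6])
b = paddle.randn([6])
ref = paddle.matmul(x, w) + b     # what the unsharded layer would produce

# each rank: partial matmul over its input shard, plus the rescaled bias (b / tp_size)
partials = [
    paddle.matmul(x[:, r * 4 : (r + 1) * 4], w[r * 4 : (r + 1) * 4, :]) + b / tp_size
    for r in range(tp_size)
]
out = partials[0] + partials[1]   # what the all-reduce produces
print(paddle.allclose(ref, out, atol=1e-5).item())  # True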