Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Optimization] Refine row parallel bias and nranks and moe all_reduce (#5247)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* rename nranks to tp_size and fix bias in v1 loader
* fix
* update
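The "fix bias in v1 loader" part concerns row-parallel layers: after this change the bias is added on every rank inside `quant_method.apply` before the tensor-parallel all-reduce, so the v1 loader divides any bias tagged `tp_row_bias` by `tensor_parallel_size` to keep the result equal to a single bias addition. A minimal standalone sketch of that identity, simulating ranks with a Python loop (illustration only, not FastDeploy code; assumes numpy):

import numpy as np

tp_size = 4
rng = np.random.default_rng(0)

x = rng.standard_normal((2, 8)).astype("float32")  # activations
w = rng.standard_normal((8, 3)).astype("float32")  # full weight
b = rng.standard_normal((3,)).astype("float32")    # full bias

# Row-parallel split: each rank holds a slice of the input dim and of the weight rows.
x_shards = np.split(x, tp_size, axis=1)
w_shards = np.split(w, tp_size, axis=0)

# Old scheme: all-reduce the partial matmuls, then add the full bias once.
ref = sum(xs @ ws for xs, ws in zip(x_shards, w_shards)) + b

# New scheme: each rank adds a pre-scaled bias (b / tp_size) before the all-reduce,
# which is what dividing the tp_row_bias tensor by tensor_parallel_size achieves.
out = sum(xs @ ws + b / tp_size for xs, ws in zip(x_shards, w_shards))

assert np.allclose(ref, out, atol=1e-5)
print("per-rank bias/tp_size + all_reduce == single bias add")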
@@ -17,7 +17,6 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.utils import ceil_div
 
@@ -241,7 +240,4 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
 
         intermediate_cache3.reshape_([token_num, top_k, hidden_size])
         out = intermediate_cache3.sum(axis=1)
-
-        if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
         return out
 
@@ -175,13 +175,6 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
         fused_moe_out = intermediate_cache3.sum(axis=1)
         fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
 
-        if layer.tp_size > 1:
-            from fastdeploy.distributed.communication import (
-                tensor_model_parallel_all_reduce,
-            )
-
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
-
         return fused_moe_out
 
     def apply(
 
@@ -211,7 +211,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp"
         self.rank: int = llm_config.parallel_config.tensor_parallel_rank
-        self.nranks = llm_config.parallel_config.tensor_parallel_size
+        self.tp_size = llm_config.parallel_config.tensor_parallel_size
 
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
@@ -325,7 +325,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
             softmax_mode=0,
         )
 
-        if self.nranks > 1:
+        if self.tp_size > 1:
             from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce_custom,
             )
@@ -368,7 +368,7 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         )
 
         # all_reduce
-        if self.nranks > 1:
+        if self.tp_size > 1:
             from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce_custom,
             )
 
@@ -20,7 +20,6 @@ import paddle
 from paddle import nn
 from paddle.nn.quant import weight_quantize
 
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
     MoEMethodBase,
     UnquantizedFusedMoEMethod,
@@ -171,9 +170,6 @@ class MetaxCutlassUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             False,
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
-
         return fused_moe_out
 
 
@@ -301,9 +297,6 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
             False,
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
-
         return fused_moe_out
 
@@ -18,7 +18,6 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
@@ -393,6 +392,4 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
 
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
-        if layer.reduce_results and layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
         return out
 
@@ -17,7 +17,6 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
 from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
 from fastdeploy.model_executor.layers.utils import get_tensor
@@ -255,8 +254,6 @@ class XPUMoEMethod(MoEMethodBase):
             layer.top_k,
             False,  # moe group, used in deepseek
         )
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
 
         return fused_moe_out
 
@@ -314,8 +311,6 @@ class XPUMoEMethod(MoEMethodBase):
             permute_indices_per_token.shape[1],
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
         return tmp_ffn_out
 
     def apply_tp(
 
@@ -79,7 +79,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
         layer.weight.set_value(weights)
 
     def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor:
-
         linear_out = paddle.matmul(x, layer.weight)
         if layer.with_bias:
             linear_out = paddle.add(linear_out, layer.bias)
@@ -423,9 +422,9 @@ class ColumnParallelLinear(LinearBase):
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
         self.fd_config = fd_config
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.input_size = input_size
-        self.output_size = divide(output_size, self.nranks)  # Split the output_size using TP inference.
+        self.output_size = divide(output_size, self.tp_size)  # Split the output_size using TP inference.
         self.hidden_size = fd_config.model_config.hidden_size
 
         super().__init__(
@@ -449,7 +448,7 @@ class ColumnParallelLinear(LinearBase):
             model_format=fd_config.model_config.model_format,
         )
 
-        if self.nranks > 0:
+        if self.tp_size > 0:
             if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
 
@@ -492,7 +491,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         """
         self.activation = activation
         self.hidden_size = fd_config.model_config.hidden_size
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.output_size = output_size
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
 
@@ -522,8 +521,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # Loaded weight is already fused on disk.
             shard_offsets = [
                 # (shard_id, shard_offset, shard_size)
-                ("gate", 0, output_size * self.nranks // 2),
-                ("up", output_size * self.nranks // 2, output_size * self.nranks // 2),
+                ("gate", 0, output_size * self.tp_size // 2),
+                ("up", output_size * self.tp_size // 2, output_size * self.tp_size // 2),
             ]
             for shard_id, shard_offset, shard_size in shard_offsets:
                 loaded_weight_shard = slice_fn(
@@ -537,13 +536,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
-        if self.nranks > 1 and output_dim is not None:
+        if self.tp_size > 1 and output_dim is not None:
             dim = -1 if output_dim else 0
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
-            block_size = size // self.nranks
+            block_size = size // self.tp_size
             shard_offset = self.local_rank * block_size
             shard_size = (self.local_rank + 1) * block_size
             loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
 
@@ -635,15 +634,15 @@ class QKVParallelLinear(ColumnParallelLinear):
         self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
         self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
         self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
-        self.num_heads_per_rank = divide(self.num_heads, self.nranks)
-        if self.kv_num_heads < self.nranks and self.nranks % self.kv_num_heads == 0:
+        self.num_heads_per_rank = divide(self.num_heads, self.tp_size)
+        if self.kv_num_heads < self.tp_size and self.tp_size % self.kv_num_heads == 0:
             self.kv_num_heads_per_rank = 1
-            self.num_kv_head_replicas = divide(self.nranks, self.kv_num_heads)
-            output_size = (self.num_heads + 2 * self.nranks) * self.head_dim
+            self.num_kv_head_replicas = divide(self.tp_size, self.kv_num_heads)
+            output_size = (self.num_heads + 2 * self.tp_size) * self.head_dim
         else:
-            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.nranks)
+            self.kv_num_heads_per_rank = divide(self.kv_num_heads, self.tp_size)
             self.num_kv_head_replicas = 1
             output_size = (self.num_heads + 2 * self.kv_num_heads) * self.head_dim
         input_size = self.hidden_size
@@ -697,7 +696,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
-        if self.nranks > 1 and output_dim is not None:
+        if self.tp_size > 1 and output_dim is not None:
             block_size = self._get_shard_size_mapping(loaded_shard_id, head_dim)
             shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
             shard_offset = shard_id * block_size
@@ -750,10 +749,10 @@ class QKVParallelLinear(ColumnParallelLinear):
         k_tensor = get_tensor(state_dict.pop(k_weight_key))
         v_tensor = get_tensor(state_dict.pop(v_weight_key))
 
-        if self.kv_num_heads < self.nranks:
+        if self.kv_num_heads < self.tp_size:
             sharedkv_index = (
                 self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads
-            ) // self.nranks
+            ) // self.tp_size
             sharedkv_start = sharedkv_index * self.head_dim
             sharedkv_end = sharedkv_start + self.head_dim
             k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
@@ -767,10 +766,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         )
         weight_tensor = paddle.transpose(weight_tensor, perm=[1, 0])
 
-        if self.fd_config.quant_config:
-            self.quant_method.process_loaded_weights(self, weight_tensor)
-        else:
-            self.weight.set_value(weight_tensor)
+        self.quant_method.process_loaded_weights(self, weight_tensor)
 
     def load_state_dict(self, state_dict: dict):
         """
 
@@ -846,10 +842,8 @@ class RowParallelLinear(LinearBase):
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
         self.fd_config = fd_config
         self.skip_quant = False
         self.ep_size = fd_config.parallel_config.expert_parallel_size
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_size = fd_config.model_config.hidden_size
         self.head_dim = fd_config.model_config.head_dim
@@ -863,7 +857,7 @@ class RowParallelLinear(LinearBase):
         if self.split_token:
             self.input_size = input_size
         else:
-            self.input_size = divide(input_size, self.nranks)
+            self.input_size = divide(input_size, self.tp_size)
         self.output_size = output_size
 
         super().__init__(
@@ -876,8 +870,7 @@ class RowParallelLinear(LinearBase):
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
        )
-        if add_bias:
-            assert with_bias, "with_bias must be True when add_bias is True."
 
        assert self.quant_method is not None
        create_weight_kwargs = dict(
            layer=self,
@@ -887,12 +880,17 @@ class RowParallelLinear(LinearBase):
            ),
            model_format=fd_config.model_config.model_format,
        )
-        if self.nranks > 0:
+        if self.tp_size > 1:
            create_weight_kwargs["split_axis"] = 0
            create_weight_kwargs["is_distributed"] = True
        self.quant_method.create_weights(**create_weight_kwargs)
 
-        self.reduce_results = reduce_results
+        self.reduce_results = reduce_results and not self.split_token
 
+        if add_bias:
+            assert with_bias, "with_bias must be True when add_bias is True."
+            if self.tp_size > 1 and self.reduce_results:
+                set_weight_attrs(self.bias, {"tp_row_bias": True})
+
     def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
         token_num = x.shape[0]
@@ -912,15 +910,11 @@ class RowParallelLinear(LinearBase):
         if self.split_token:
             x = self.all2all_transpose(x)
 
-        if self.fd_config.quant_config:
-            out = self.quant_method.apply(self, x)
-        else:
-            out = paddle.matmul(x, self.weight)
+        out = self.quant_method.apply(self, x)
 
-        if self.reduce_results and self.nranks > 1 and not self.split_token:
+        if self.reduce_results and self.tp_size > 1:
             out = tensor_model_parallel_all_reduce(out, self.tp_group)
-        if not self.fd_config.quant_config and self.add_bias:
-            out = paddle.add(out, self.bias)
 
         return out
 
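For the RowParallelLinear hunks above, the new tp_row_bias flag is only an attribute tag: it is set on the bias parameter when the layer is built and later read by the v1 weight loader, which rescales the checkpoint value. A rough sketch of that hand-off with hypothetical FakeParam / set_weight_attrs / load_row_parallel_bias stand-ins (not the real FastDeploy helpers):

class FakeParam:
    """Stand-in for a framework parameter that can carry extra attributes."""
    def __init__(self, name):
        self.name = name

def set_weight_attrs(param, attrs):
    # The real helper attaches metadata to the parameter; a plain setattr works here.
    for key, value in attrs.items():
        setattr(param, key, value)

def load_row_parallel_bias(param, loaded_bias, tensor_parallel_size):
    # Mirrors the loader branch in the diff: a bias tagged tp_row_bias is divided
    # by tp_size because every rank will add it before the all-reduce.
    if getattr(param, "tp_row_bias", None):
        loaded_bias = [v / tensor_parallel_size for v in loaded_bias]
    return loaded_bias

bias = FakeParam("o_proj.bias")
set_weight_attrs(bias, {"tp_row_bias": True})  # done once, at layer construction
print(load_row_parallel_bias(bias, [0.4, 0.8], tensor_parallel_size=4))  # [0.1, 0.2]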
@@ -950,16 +944,15 @@ class KVBatchLinear(nn.Layer):
             qk_nope_head_dim (int): Dimension for Q/K projection (nope part). Defaults to None.
             v_head_dim (int): Dimension for V projection. Defaults to None.
             with_bias (bool): Whether to include bias or not. Defaults to False.
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.kv_lora_rank = kv_lora_rank
         self.num_attention_heads = num_attention_heads
         self.qk_nope_head_dim = qk_nope_head_dim
         self.v_head_dim = v_head_dim
         # Split num_attention_heads when using TP inference.
-        self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
+        self.num_heads_per_partition = divide(num_attention_heads, self.tp_size)
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
         self.fd_config = fd_config
         self.kv_b_proj = kv_b_proj
 
@@ -68,11 +68,11 @@ class ParallelLMHead(nn.Layer):
         self.embedding_dim = embedding_dim
         self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.fd_config = fd_config
         self.padding_size = padding_size
 
-        if num_embeddings % self.nranks != 0:
+        if num_embeddings % self.tp_size != 0:
             num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
         self.num_embeddings = num_embeddings
 
@@ -20,7 +20,6 @@ from paddle.nn.quant import weight_quantize
 from paddleformers.utils.log import logger
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform
 
 from ..utils import get_tensor
@@ -390,9 +389,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             routed_scaling_factor=1.0,
         )
 
-        if layer.reduce_results and layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
-
         return fused_moe_out
 
@@ -19,7 +19,6 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
 
@@ -423,7 +422,5 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             False,  # norm_topk_prob
             1.0,
         )[0]
-        if layer.tp_size > 1:
-            tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
 
         return tmp_ffn_out
 
@@ -18,7 +18,6 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
     tritonmoe_preprocess_func,
@@ -351,7 +350,4 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
         ffn_out.reshape_([token_num, -1, hidden_size])
         ffn_out = ffn_out.sum(axis=1)
 
-        if layer.reduce_results and layer.tp_size > 1:
-            ffn_out = tensor_model_parallel_all_reduce(ffn_out)
-
         return ffn_out
 
@@ -18,7 +18,6 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import (
     TensorTracker,
@@ -433,8 +432,6 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
 
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
-        if layer.reduce_results and layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
 
         return out
 
@@ -838,9 +835,6 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
 
-        if layer.reduce_results and layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
-
         return out
 
 
@@ -1129,9 +1123,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
 
-        if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
-
         return out
 
 
@@ -1625,7 +1616,4 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         intermediate_cache3.reshape_([token_num, top_k, hidden_size])
         out = intermediate_cache3.sum(axis=1)
 
-        if layer.tp_size > 1:
-            out = tensor_model_parallel_all_reduce(out)
-
         return out
 
@@ -18,7 +18,6 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
 from fastdeploy.utils import ceil_div
 
@@ -316,9 +315,6 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
             routed_scaling_factor=1.0,
         )
 
-        if layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
-
         return fused_moe_out
 
 
@@ -486,7 +482,4 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
 
         fused_moe_out = paddle.sum(intermediate_cache3, axis=1)
 
-        if layer.tp_size > 1:
-            fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
-
         return fused_moe_out
 
@@ -21,6 +21,7 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy import envs
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import h2d_copy, slice_fn
 from fastdeploy.platforms import current_platform
@@ -634,4 +635,7 @@ class FusedMoE(nn.Layer):
             out = self.forward_split_allgather(x, gate)
         else:
             out = self.quant_method.apply(self, x, gate)
+
+        if self.reduce_results and self.tp_size > 1:
+            out = tensor_model_parallel_all_reduce(out, self.tp_group)
         return out
 
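The moe.py hunk above is where the all-reduce removed from the individual MoE backends is reinstated once: backend apply methods now return only the per-rank partial expert output, and the FusedMoE layer performs a single reduce gated on reduce_results and tp_size. A simplified illustration of that control flow with toy classes (not the FastDeploy API):

class ToyBackend:
    def apply(self, x):
        # Returns only this rank's partial expert output; no collective here anymore.
        return [v * 0.5 for v in x]

class ToyFusedMoE:
    def __init__(self, backend, tp_size, reduce_results, all_reduce):
        self.backend = backend
        self.tp_size = tp_size
        self.reduce_results = reduce_results
        self.all_reduce = all_reduce  # injected collective, e.g. tensor_model_parallel_all_reduce

    def forward(self, x):
        out = self.backend.apply(x)
        # Single, centralized reduction instead of one per backend implementation.
        if self.reduce_results and self.tp_size > 1:
            out = self.all_reduce(out)
        return out

# Pretend the collective sums the same partial result from 2 ranks.
moe = ToyFusedMoE(ToyBackend(), tp_size=2, reduce_results=True,
                  all_reduce=lambda out: [v * 2 for v in out])
print(moe.forward([1.0, 2.0]))  # [1.0, 2.0]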
@@ -56,7 +56,7 @@ class ParallelEHProjection(nn.Layer):
         self.fd_config = fd_config
         self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -84,7 +84,7 @@ class ParallelEHProjection(nn.Layer):
                 self.linear.bias,
                 {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
             )
-            if self.nranks > 1:
+            if self.tp_size > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
@@ -103,7 +103,7 @@ class ParallelEHProjection(nn.Layer):
                     "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                 },
             )
-            if self.nranks > 1:
+            if self.tp_size > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
                 set_weight_attrs(
                     self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
 
@@ -66,7 +66,6 @@ class Ernie4_5_MLP(nn.Layer):
         reduce_results: bool = True,
     ) -> None:
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.up_gate_proj = MergedColumnParallelLinear(
             fd_config=fd_config,
             prefix=f"{prefix}.up_gate_proj",
 
@@ -61,7 +61,6 @@ class Qwen2MLP(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.up_gate_proj = MergedColumnParallelLinear(
             fd_config=fd_config,
             prefix=f"{prefix}.up_gate_proj",
 
@@ -59,7 +59,6 @@ class Qwen3Attention(nn.Layer):
         self.head_dim = fd_config.model_config.head_dim
 
         self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)
-        nranks = fd_config.parallel_config.tensor_parallel_size
 
         self.o_proj = RowParallelLinear(
             fd_config,
@@ -91,10 +90,10 @@ class Qwen3Attention(nn.Layer):
             begin_norm_axis=2,
         )
 
-        nranks = fd_config.parallel_config.tensor_parallel_size
-        num_kv_heads_replicas = max(1, nranks // fd_config.model_config.num_key_value_heads)
-        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
-        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // nranks
+        tp_size = fd_config.parallel_config.tensor_parallel_size
+        num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
+        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
+        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size
 
     def load_state_dict(self, state_dict):
         """ """
 
@@ -97,8 +97,6 @@ class Qwen3MLP(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.nranks = fd_config.parallel_config.tensor_parallel_size
-
         self.up_gate_proj = MergedColumnParallelLinear(
             fd_config,
             prefix=f"{prefix}.up_gate_proj",
 
@@ -298,6 +298,10 @@ def default_weight_loader(fd_config: FDConfig = None) -> None:
             shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
             loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
 
+        tp_row_bias = getattr(param, "tp_row_bias", None)
+        if tp_row_bias:
+            loaded_weight = loaded_weight / fd_config.parallel_config.tensor_parallel_size
+
         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
         loaded_weight = fd_cast(loaded_weight, param)
         if param.shape != loaded_weight.shape: