mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] support rl_tp_degree (#3934)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Feature] support rl_tp_degree * add rl_tp_degree in lmhead * add rl_tp_degree in bias * fix split_axis=0 in bias * fix split_axis in weight * fix bias rl_tp_degree * fix bias rl_tp_degree * change attr to dict --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -356,11 +356,21 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
|
||||
if self.nranks > 0:
|
||||
_set_var_distributed(self.weight, split_axis=-1)
|
||||
if self.with_bias:
|
||||
# col parallel
|
||||
_set_var_distributed(self.bias, split_axis=1)
|
||||
_set_var_distributed(self.bias, split_axis=0)
|
||||
set_weight_attrs(self.bias, {"output_dim": True})
|
||||
|
||||
# set_rl_tp_degree
|
||||
set_weight_attrs(
|
||||
self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
if self.with_bias:
|
||||
set_weight_attrs(
|
||||
self.bias, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
@@ -743,6 +753,7 @@ class RowParallelLinear(LinearBase):
|
||||
model_format=fd_config.model_config.model_format,
|
||||
)
|
||||
if self.nranks > 0:
|
||||
_set_var_distributed(self.weight, split_axis=0)
|
||||
if self.with_bias:
|
||||
# col parallel
|
||||
_set_var_distributed(self.bias, split_axis=0)
|
||||
@@ -755,6 +766,11 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
# set_rl_tp_degree
|
||||
set_weight_attrs(
|
||||
self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
if self.fd_config.quant_config:
|
||||
out = self.quant_method.apply(self, x)
|
||||
|
Reference in New Issue
Block a user