mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
@@ -56,9 +56,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
split_axis = extra_weight_attrs.get("split_axis")
|
||||
if hasattr(layer, "nranks") and layer.nranks > 0:
|
||||
_set_var_distributed(layer.weight, split_axis=split_axis)
|
||||
|
||||
if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
|
||||
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
|
||||
@@ -882,15 +879,18 @@ class RowParallelLinear(LinearBase):
|
||||
if add_bias:
|
||||
assert with_bias, "with_bias must be True when add_bias is True."
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
self,
|
||||
split_axis=0,
|
||||
create_weight_kwargs = dict(
|
||||
layer=self,
|
||||
output_dim=None if self.split_token else False,
|
||||
weight_loader=(
|
||||
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
|
||||
),
|
||||
model_format=fd_config.model_config.model_format,
|
||||
)
|
||||
if self.nranks > 0:
|
||||
create_weight_kwargs["split_axis"] = 0
|
||||
create_weight_kwargs["is_distributed"] = True
|
||||
self.quant_method.create_weights(**create_weight_kwargs)
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user