Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
qwen3 0.3B fix (#3255)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled
@@ -266,10 +266,6 @@ class ReplicatedLinear(LinearBase):
         )
 
         self.hidden_size = fd_config.model_config.hidden_size
-        self.weight_shape = [
-            self.input_size,
-            self.output_size,
-        ]
 
         assert self.quant_method is not None
         self.quant_method.create_weights(
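Taken together with the ColumnParallelLinear and RowParallelLinear hunks below, the pattern of this change is that each subclass stops building self.weight_shape by hand and instead passes its final (possibly TP-split) sizes up to LinearBase. A minimal sketch of the base-class behavior this appears to rely on, assuming LinearBase derives the shape from the sizes it receives (illustrative only, not the actual FastDeploy implementation):

    # Sketch under the assumption that LinearBase.__init__ records the
    # sizes it is given and derives weight_shape from them; this is why
    # the subclasses can drop their duplicated weight_shape assignments.
    class LinearBase:
        def __init__(self, fd_config, prefix, input_size, output_size,
                     with_bias=False, add_bias=False, skip_quant=False):
            self.input_size = input_size
            self.output_size = output_size
            self.weight_shape = [self.input_size, self.output_size]

This would also explain why the super().__init__ calls move below the size computations in the hunks that follow: the base class must see the per-rank sizes, not the global ones.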
@@ -311,24 +307,21 @@ class ColumnParallelLinear(LinearBase):
             add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
-        super().__init__(
-            fd_config=fd_config,
-            prefix=prefix,
-            input_size=input_size,
-            output_size=output_size,
-            with_bias=with_bias,
-            add_bias=add_bias,
-            skip_quant=skip_quant,
-        )
         self.fd_config = fd_config
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.input_size = input_size
         self.output_size = divide(output_size, self.nranks)  # Split the output_size using TP inference.
         self.hidden_size = fd_config.model_config.hidden_size
-        self.weight_shape = [
-            self.input_size,
-            self.output_size,
-        ]
+
+        super().__init__(
+            fd_config=fd_config,
+            prefix=prefix,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            with_bias=with_bias,
+            add_bias=add_bias,
+            skip_quant=skip_quant,
+        )
 
         assert self.quant_method is not None
         self.quant_method.create_weights(
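For context, divide(output_size, self.nranks) splits the output dimension of the weight evenly across tensor-parallel ranks. A sketch of the helper's assumed semantics (the usual Megatron-style exact split; names as used in the diff):

    # Assumed semantics of divide(): an exact integer split of a
    # dimension across tensor-parallel ranks, failing loudly if the
    # dimension is not divisible.
    def divide(numerator: int, denominator: int) -> int:
        assert numerator % denominator == 0, (
            f"{numerator} is not evenly divisible by {denominator}"
        )
        return numerator // denominator

Calling super().__init__ only after this split appears to be the substance of the change here: the base class then sees the per-rank output size rather than the full one.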
@@ -634,15 +627,6 @@ class RowParallelLinear(LinearBase):
             add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
             skip_quant (bool): Whether to skip quantization. Defaults to False.
         """
-        super().__init__(
-            fd_config=fd_config,
-            prefix=prefix,
-            input_size=input_size,
-            output_size=output_size,
-            with_bias=with_bias,
-            add_bias=add_bias,
-            skip_quant=skip_quant,
-        )
         self.fd_config = fd_config
         self.skip_quant = False
         self.nranks = fd_config.parallel_config.tensor_parallel_size
@@ -654,11 +638,15 @@ class RowParallelLinear(LinearBase):
         self.input_size = divide(input_size, self.nranks)
         self.output_size = output_size
 
-        self.weight_shape = [
-            self.input_size,
-            self.output_size,
-        ]
-        self._dtype = self._helper.get_default_dtype()
+        super().__init__(
+            fd_config=fd_config,
+            prefix=prefix,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            with_bias=with_bias,
+            add_bias=add_bias,
+            skip_quant=skip_quant,
+        )
 
         assert self.quant_method is not None
         self.quant_method.create_weights(
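RowParallelLinear makes the complementary choice: it splits the input dimension across ranks and keeps the full output dimension, so a column-parallel layer followed by a row-parallel layer composes cleanly under tensor parallelism. A worked example with hypothetical sizes, reusing the divide() sketch above:

    # Hypothetical sizes; shows which dimension each layer type shards.
    input_size, output_size, nranks = 1024, 4096, 2

    column_parallel_shape = [input_size, divide(output_size, nranks)]  # [1024, 2048]
    row_parallel_shape = [divide(input_size, nranks), output_size]     # [512, 4096]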
@@ -286,6 +286,9 @@ class Qwen3ForCausalLM(ModelForCasualLM):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight)
 
+        if self.tie_word_embeddings:
+            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         """
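The Qwen3ForCausalLM hunk is the user-visible part of the fix: for checkpoints that set tie_word_embeddings (as the small Qwen3 variants do), the output projection reuses the token-embedding table instead of loading a separate lm_head weight. Because Paddle's nn.Linear stores its weight as [in_features, out_features] while the embedding table is [vocab_size, hidden_size], the shared matrix has to be transposed, which is exactly the transpose([1, 0]) in the added line. A minimal standalone sketch with hypothetical shapes (FastDeploy wraps these tensors in its own lm_head and embed_tokens modules):

    # Minimal sketch of weight tying in Paddle; sizes are hypothetical.
    import paddle

    vocab_size, hidden_size = 151936, 1024
    embed_tokens = paddle.nn.Embedding(vocab_size, hidden_size)           # weight: [vocab, hidden]
    lm_head = paddle.nn.Linear(hidden_size, vocab_size, bias_attr=False)  # weight: [hidden, vocab]

    # Tie: the lm_head weight is the transposed embedding table.
    lm_head.weight.set_value(embed_tokens.weight.transpose([1, 0]))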