mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[Feature] support tensor-parallel-size>num_key_value_heads for qwen3 (#2799)
This commit is contained in:
@@ -443,6 +443,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
q_tensor = get_tensor(state_dict.pop(q_weight_key))
|
||||
k_tensor = get_tensor(state_dict.pop(k_weight_key))
|
||||
v_tensor = get_tensor(state_dict.pop(v_weight_key))
|
||||
|
||||
if self.kv_num_heads < self.nranks:
|
||||
sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
|
||||
sharedkv_start = sharedkv_index * self.head_dim
|
||||
sharedkv_end = sharedkv_start + self.head_dim
|
||||
k_tensor = k_tensor[ : , sharedkv_start : sharedkv_end]
|
||||
v_tensor = v_tensor[ : , sharedkv_start : sharedkv_end]
|
||||
weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor],
|
||||
axis=-1).transpose([1, 0])
|
||||
weight_tensor = weight_tensor.reshape([
|
||||
|
@@ -54,40 +54,42 @@ class Qwen3Attention(nn.Layer):
|
||||
super().__init__()
|
||||
|
||||
self.fd_config = fd_config
|
||||
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
|
||||
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim // nranks
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(fd_config=fd_config,
|
||||
self.qkv_proj = QKVParallelLinear(fd_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
with_bias=False)
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
fd_config=fd_config,
|
||||
fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
|
||||
self.attn = Attention(fd_config=fd_config,
|
||||
self.attn = Attention(fd_config,
|
||||
layer_id=layer_id,
|
||||
prefix=prefix,
|
||||
use_neox_rotary_style=True)
|
||||
|
||||
self.q_norm = RMSNorm(fd_config=fd_config,
|
||||
hidden_size=fd_config.model_config.head_dim,
|
||||
self.q_norm = RMSNorm(fd_config,
|
||||
hidden_size=self.head_dim,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.q_norm",
|
||||
begin_norm_axis=2)
|
||||
self.k_norm = RMSNorm(fd_config=fd_config,
|
||||
hidden_size=fd_config.model_config.head_dim,
|
||||
self.k_norm = RMSNorm(fd_config,
|
||||
hidden_size=self.head_dim,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.k_norm",
|
||||
begin_norm_axis=2)
|
||||
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
num_kv_heads_replicas = max(1, nranks // fd_config.model_config.num_key_value_heads)
|
||||
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
|
||||
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // nranks
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
"""
|
||||
@@ -104,7 +106,6 @@ class Qwen3Attention(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
qkv_out = self.qkv_proj(hidden_states)
|
||||
|
||||
# origin_qkv_out = qkv_out
|
||||
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
|
||||
axis=-1)
|
||||
|
@@ -35,6 +35,7 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
|
||||
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
|
||||
from fastdeploy.model_executor.layers.normalization import RMSNorm
|
||||
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
|
||||
from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@@ -88,91 +89,6 @@ class Qwen3MLP(nn.Layer):
|
||||
return down_out
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fd_config: FDConfig,
|
||||
layer_id: int,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.fd_config = fd_config
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(fd_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
with_bias=False)
|
||||
nranks = fd_config.parallel_config.tensor_parallel_degree
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
fd_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.head_dim *
|
||||
fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
)
|
||||
|
||||
self.attn = Attention(fd_config,
|
||||
layer_id=layer_id,
|
||||
prefix=prefix,
|
||||
use_neox_rotary_style=True)
|
||||
|
||||
self.q_norm = RMSNorm(fd_config,
|
||||
hidden_size=self.head_dim,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.q_norm",
|
||||
begin_norm_axis=2)
|
||||
self.k_norm = RMSNorm(fd_config,
|
||||
hidden_size=self.head_dim,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.k_norm",
|
||||
begin_norm_axis=2)
|
||||
|
||||
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
|
||||
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim // nranks
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
"""
|
||||
self.qkv_proj.load_state_dict(state_dict)
|
||||
self.o_proj.load_state_dict(state_dict)
|
||||
self.q_norm.load_state_dict(state_dict)
|
||||
self.k_norm.load_state_dict(state_dict)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
forward_meta: ForwardMeta,
|
||||
hidden_states: paddle.Tensor,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
qkv_out = self.qkv_proj(hidden_states)
|
||||
# origin_qkv_out = qkv_out
|
||||
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
|
||||
axis=-1)
|
||||
|
||||
q_by_head = q.reshape(
|
||||
[*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.reshape(q.shape)
|
||||
|
||||
k_by_head = k.reshape(
|
||||
[*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.reshape(k.shape)
|
||||
|
||||
qkv_out = paddle.concat([q, k, v], axis=-1)
|
||||
|
||||
atten_out = self.attn(
|
||||
qkv=qkv_out,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
output = self.o_proj(atten_out)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
@@ -711,9 +711,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
assert len(self.attn_backends) == 0
|
||||
|
||||
num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree
|
||||
self.model_config.kv_num_heads = int(
|
||||
self.model_config.kv_num_heads = max(1, int(
|
||||
self.model_config.num_key_value_heads
|
||||
) // self.parallel_config.tensor_parallel_degree
|
||||
) // self.parallel_config.tensor_parallel_degree)
|
||||
head_dim = self.model_config.head_dim
|
||||
|
||||
# Get the attention backend
|
||||
|
Reference in New Issue
Block a user