[Feature] support tensor-parallel-size>num_key_value_heads for qwen3 (#2799)

Author:    zhink
Date:      2025-07-11 15:09:43 +08:00
Committer: GitHub
Parent:    2c3607407f
Commit:    c08561c13a

4 changed files with 23 additions and 99 deletions


@@ -443,6 +443,13 @@ class QKVParallelLinear(ColumnParallelLinear):
             q_tensor = get_tensor(state_dict.pop(q_weight_key))
             k_tensor = get_tensor(state_dict.pop(k_weight_key))
             v_tensor = get_tensor(state_dict.pop(v_weight_key))
+            if self.kv_num_heads < self.nranks:
+                sharedkv_index = (self.fd_config.parallel_config.tensor_parallel_rank * self.kv_num_heads) // self.nranks
+                sharedkv_start = sharedkv_index * self.head_dim
+                sharedkv_end = sharedkv_start + self.head_dim
+                k_tensor = k_tensor[:, sharedkv_start:sharedkv_end]
+                v_tensor = v_tensor[:, sharedkv_start:sharedkv_end]
             weight_tensor = paddle.concat([q_tensor, k_tensor, v_tensor],
                                           axis=-1).transpose([1, 0])
             weight_tensor = weight_tensor.reshape([
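
Note (not part of the diff): when kv_num_heads < nranks, the slice above keeps exactly one KV head's columns per tensor-parallel rank, so several ranks end up holding the same head. A minimal sketch of that indexing, using NumPy stand-ins and made-up sizes:

import numpy as np

# Hypothetical sizes: 2 KV heads shared across 4 TP ranks, head_dim = 4.
kv_num_heads, nranks, head_dim = 2, 4, 4
k_weight = np.arange(8 * kv_num_heads * head_dim).reshape(8, kv_num_heads * head_dim)

for rank in range(nranks):
    sharedkv_index = (rank * kv_num_heads) // nranks   # same indexing as the diff
    sharedkv_start = sharedkv_index * head_dim
    sharedkv_end = sharedkv_start + head_dim
    k_slice = k_weight[:, sharedkv_start:sharedkv_end]
    print(rank, sharedkv_index, k_slice.shape)          # ranks 0-1 get head 0, ranks 2-3 get head 1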


@@ -54,40 +54,42 @@ class Qwen3Attention(nn.Layer):
         super().__init__()
         self.fd_config = fd_config
         self.head_dim = fd_config.model_config.head_dim
-        nranks = fd_config.parallel_config.tensor_parallel_degree
-        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
-        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim // nranks
-        self.qkv_proj = QKVParallelLinear(fd_config=fd_config,
+        self.qkv_proj = QKVParallelLinear(fd_config,
                                           prefix=f"{prefix}.qkv_proj",
                                           with_bias=False)
+        nranks = fd_config.parallel_config.tensor_parallel_degree
         self.o_proj = RowParallelLinear(
-            fd_config=fd_config,
+            fd_config,
             prefix=f"{prefix}.o_proj",
             input_size=fd_config.model_config.head_dim *
             fd_config.model_config.num_attention_heads,
             output_size=fd_config.model_config.hidden_size,
         )
-        self.attn = Attention(fd_config=fd_config,
+        self.attn = Attention(fd_config,
                               layer_id=layer_id,
                               prefix=prefix,
                               use_neox_rotary_style=True)
-        self.q_norm = RMSNorm(fd_config=fd_config,
-                              hidden_size=fd_config.model_config.head_dim,
+        self.q_norm = RMSNorm(fd_config,
+                              hidden_size=self.head_dim,
                               eps=fd_config.model_config.rms_norm_eps,
                               prefix=f"{prefix}.q_norm",
                               begin_norm_axis=2)
-        self.k_norm = RMSNorm(fd_config=fd_config,
-                              hidden_size=fd_config.model_config.head_dim,
+        self.k_norm = RMSNorm(fd_config,
+                              hidden_size=self.head_dim,
                               eps=fd_config.model_config.rms_norm_eps,
                               prefix=f"{prefix}.k_norm",
                               begin_norm_axis=2)
+        nranks = fd_config.parallel_config.tensor_parallel_degree
+        num_kv_heads_replicas = max(1, nranks // fd_config.model_config.num_key_value_heads)
+        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
+        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // nranks

     def load_state_dict(self, state_dict):
         """
         """
@@ -104,7 +106,6 @@ class Qwen3Attention(nn.Layer):
         """
         """
         qkv_out = self.qkv_proj(hidden_states)
         # origin_qkv_out = qkv_out
         q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
                                 axis=-1)
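
Note (not part of the diff): num_kv_heads_replicas makes kv_size land on a whole head per rank when the TP degree exceeds the number of KV heads, so the split([q_size, kv_size, kv_size]) in forward stays consistent with the replicated K/V weights loaded above. A quick check with made-up Qwen3-like sizes:

# Hypothetical sizes: 16 query heads, 4 KV heads, head_dim 128.
num_attention_heads, num_key_value_heads, head_dim = 16, 4, 128

for nranks in (2, 4, 8, 16):
    num_kv_heads_replicas = max(1, nranks // num_key_value_heads)
    q_size = num_attention_heads * head_dim // nranks
    kv_size = num_key_value_heads * head_dim * num_kv_heads_replicas // nranks
    old_kv_size = num_key_value_heads * head_dim // nranks  # pre-patch formula
    # At nranks=8 the old formula gives 64 (half a head); the new one gives 128, one full head.
    print(nranks, num_kv_heads_replicas, q_size, kv_size, old_kv_size)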


@@ -35,6 +35,7 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
 from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.models.model_base import ModelForCasualLM
+from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
 from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -88,91 +89,6 @@ class Qwen3MLP(nn.Layer):
         return down_out
 
-
-class Qwen3Attention(nn.Layer):
-    """
-    """
-
-    def __init__(self,
-                 fd_config: FDConfig,
-                 layer_id: int,
-                 prefix: str = "") -> None:
-        super().__init__()
-        self.fd_config = fd_config
-        self.head_dim = fd_config.model_config.head_dim
-
-        self.qkv_proj = QKVParallelLinear(fd_config,
-                                          prefix=f"{prefix}.qkv_proj",
-                                          with_bias=False)
-        nranks = fd_config.parallel_config.tensor_parallel_degree
-
-        self.o_proj = RowParallelLinear(
-            fd_config,
-            prefix=f"{prefix}.o_proj",
-            input_size=fd_config.model_config.head_dim *
-            fd_config.model_config.num_attention_heads,
-            output_size=fd_config.model_config.hidden_size,
-        )
-
-        self.attn = Attention(fd_config,
-                              layer_id=layer_id,
-                              prefix=prefix,
-                              use_neox_rotary_style=True)
-
-        self.q_norm = RMSNorm(fd_config,
-                              hidden_size=self.head_dim,
-                              eps=fd_config.model_config.rms_norm_eps,
-                              prefix=f"{prefix}.q_norm",
-                              begin_norm_axis=2)
-        self.k_norm = RMSNorm(fd_config,
-                              hidden_size=self.head_dim,
-                              eps=fd_config.model_config.rms_norm_eps,
-                              prefix=f"{prefix}.k_norm",
-                              begin_norm_axis=2)
-
-        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // nranks
-        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim // nranks
-
-    def load_state_dict(self, state_dict):
-        """
-        """
-        self.qkv_proj.load_state_dict(state_dict)
-        self.o_proj.load_state_dict(state_dict)
-        self.q_norm.load_state_dict(state_dict)
-        self.k_norm.load_state_dict(state_dict)
-
-    def forward(
-        self,
-        forward_meta: ForwardMeta,
-        hidden_states: paddle.Tensor,
-    ):
-        """
-        """
-        qkv_out = self.qkv_proj(hidden_states)
-        # origin_qkv_out = qkv_out
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                axis=-1)
-
-        q_by_head = q.reshape(
-            [*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.reshape(q.shape)
-
-        k_by_head = k.reshape(
-            [*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
-        k_by_head = self.k_norm(k_by_head)
-        k = k_by_head.reshape(k.shape)
-
-        qkv_out = paddle.concat([q, k, v], axis=-1)
-
-        atten_out = self.attn(
-            qkv=qkv_out,
-            forward_meta=forward_meta,
-        )
-        output = self.o_proj(atten_out)
-        return output
-
 
 class Qwen3DecoderLayer(nn.Layer):
     """
     """


@@ -711,9 +711,9 @@ class GPUModelRunner(ModelRunnerBase):
         assert len(self.attn_backends) == 0
         num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree
-        self.model_config.kv_num_heads = int(
+        self.model_config.kv_num_heads = max(1, int(
             self.model_config.num_key_value_heads
-        ) // self.parallel_config.tensor_parallel_degree
+        ) // self.parallel_config.tensor_parallel_degree)
         head_dim = self.model_config.head_dim
         # Get the attention backend
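
Note (not part of the diff): without the max(1, ...) guard, integer division drives kv_num_heads to 0 whenever the TP degree exceeds num_key_value_heads; the clamp keeps one (replicated) KV head per rank, matching the weight-loading change above. For example:

num_key_value_heads = 4
for tensor_parallel_degree in (2, 4, 8, 16):
    old = int(num_key_value_heads) // tensor_parallel_degree          # 0 once TP > KV heads
    new = max(1, int(num_key_value_heads) // tensor_parallel_degree)  # clamped to at least 1
    print(tensor_parallel_degree, old, new)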