mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Adapt for iluvatar gpu (#2684)
This commit is contained in:
@@ -213,8 +213,14 @@ def gqa_qkv_split_func(
|
||||
return np.split(tensor, degree, axis=0)
|
||||
|
||||
q_list = split_tensor(q, tensor_parallel_degree)
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
|
||||
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
|
||||
if repeat_kv:
|
||||
k_list = split_tensor(k, num_key_value_heads)
|
||||
v_list = split_tensor(v, num_key_value_heads)
|
||||
else:
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
|
||||
if tensor_parallel_rank is None:
|
||||
res = []
|
||||
@@ -236,8 +242,8 @@ def gqa_qkv_split_func(
|
||||
return paddle.concat(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
@@ -245,8 +251,8 @@ def gqa_qkv_split_func(
|
||||
return paddle.concat(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
@@ -255,8 +261,8 @@ def gqa_qkv_split_func(
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
@@ -264,8 +270,8 @@ def gqa_qkv_split_func(
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
@@ -281,8 +287,8 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
||||
def fn(weight_list, is_column=True):
|
||||
"""fn"""
|
||||
tensor_parallel_degree = len(weight_list)
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree # noqa: F823
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree # noqa: F823
|
||||
|
||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||
|
||||
|
Reference in New Issue
Block a user