Adapt for Iluvatar GPU (#2684)

Author: liddk1121
Date: 2025-07-07 16:53:14 +08:00
Committed by: GitHub
Parent: 2579e8fea8
Commit: 1b54a2831e
50 changed files with 4485 additions and 80 deletions


@@ -213,8 +213,14 @@ def gqa_qkv_split_func(
                     return np.split(tensor, degree, axis=0)
 
         q_list = split_tensor(q, tensor_parallel_degree)
-        k_list = split_tensor(k, tensor_parallel_degree)
-        v_list = split_tensor(v, tensor_parallel_degree)
+        repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
+        repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
+        if repeat_kv:
+            k_list = split_tensor(k, num_key_value_heads)
+            v_list = split_tensor(v, num_key_value_heads)
+        else:
+            k_list = split_tensor(k, tensor_parallel_degree)
+            v_list = split_tensor(v, tensor_parallel_degree)
 
         if tensor_parallel_rank is None:
             res = []
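The added branch covers the GQA case where there are fewer KV heads than tensor-parallel ranks: K and V are then split into one shard per KV head, and each rank later picks its shard with tensor_parallel_rank // repeat_num, so repeat_num consecutive ranks share the same KV shard. A minimal NumPy-only sketch of that split logic, with hypothetical head counts and a plain np.split standing in for the split_tensor helper:

import numpy as np

# Hypothetical GQA setup: 8 TP ranks but only 2 KV heads, head_dim of 4.
num_key_value_heads = 2
tensor_parallel_degree = 8
head_dim = 4

k = np.arange(num_key_value_heads * head_dim, dtype=np.float32)  # fused K weight row

repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1

# One shard per KV head (not per rank) when KV heads must be repeated.
k_list = np.split(k, num_key_value_heads if repeat_kv else tensor_parallel_degree, axis=-1)

# Groups of repeat_num consecutive ranks read the same KV shard.
for rank in range(tensor_parallel_degree):
    shard = k_list[rank // repeat_num]
    assert shard.shape == (head_dim,)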
@@ -236,8 +242,8 @@ def gqa_qkv_split_func(
                     return paddle.concat(
                         [
                             q_list[tensor_parallel_rank],
-                            k_list[tensor_parallel_rank],
-                            v_list[tensor_parallel_rank],
+                            k_list[tensor_parallel_rank // repeat_num],
+                            v_list[tensor_parallel_rank // repeat_num],
                         ],
                         axis=-1,
                     )
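For example, assuming the hypothetical setup above (8 TP ranks, 2 KV heads), repeat_num is 4, so ranks 0-3 concatenate their Q shard with k_list[0]/v_list[0] and ranks 4-7 with k_list[1]/v_list[1]; the remaining three concat branches below rely on the same index mapping:

repeat_num = 4  # hypothetical: tensor_parallel_degree=8, num_key_value_heads=2
assert [rank // repeat_num for rank in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]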
@@ -245,8 +251,8 @@ def gqa_qkv_split_func(
                     return paddle.concat(
                         [
                             q_list[tensor_parallel_rank],
-                            k_list[tensor_parallel_rank],
-                            v_list[tensor_parallel_rank],
+                            k_list[tensor_parallel_rank // repeat_num],
+                            v_list[tensor_parallel_rank // repeat_num],
                         ],
                         axis=0,
                     )
@@ -255,8 +261,8 @@ def gqa_qkv_split_func(
                     return np.concatenate(
                         [
                             q_list[tensor_parallel_rank],
-                            k_list[tensor_parallel_rank],
-                            v_list[tensor_parallel_rank],
+                            k_list[tensor_parallel_rank // repeat_num],
+                            v_list[tensor_parallel_rank // repeat_num],
                         ],
                         axis=-1,
                     )
@@ -264,8 +270,8 @@ def gqa_qkv_split_func(
                     return np.concatenate(
                         [
                             q_list[tensor_parallel_rank],
-                            k_list[tensor_parallel_rank],
-                            v_list[tensor_parallel_rank],
+                            k_list[tensor_parallel_rank // repeat_num],
+                            v_list[tensor_parallel_rank // repeat_num],
                         ],
                         axis=0,
                     )
@@ -281,8 +287,8 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
     def fn(weight_list, is_column=True):
         """fn"""
         tensor_parallel_degree = len(weight_list)
-        num_attention_heads = num_attention_heads // tensor_parallel_degree
-        num_key_value_heads = num_key_value_heads // tensor_parallel_degree
+        num_attention_heads = num_attention_heads // tensor_parallel_degree  # noqa: F823
+        num_key_value_heads = num_key_value_heads // tensor_parallel_degree  # noqa: F823
         is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
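The added # noqa: F823 markers silence pyflakes' "local variable defined in enclosing scope referenced before assignment" check: because fn assigns to num_attention_heads and num_key_value_heads, Python treats those names as locals of fn, so the read on the right-hand side refers to the not-yet-assigned locals rather than to the arguments of the enclosing gqa_qkv_merge_func. A small self-contained sketch of the pattern the check flags, using hypothetical outer/inner names:

def outer(n):
    def inner():
        # Assigning to `n` makes it local to `inner`; the read on the right-hand
        # side is what pyflakes reports as F823 (and what raises at runtime here).
        n = n // 2  # noqa: F823
        return n

    return inner

try:
    outer(8)()
except UnboundLocalError as err:
    print("F823 pattern:", err)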