mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-26 10:00:33 +08:00
Adapt for iluvatar gpu (#2684)
This commit is contained in:
@@ -213,8 +213,14 @@ def gqa_qkv_split_func(
|
||||
return np.split(tensor, degree, axis=0)
|
||||
|
||||
q_list = split_tensor(q, tensor_parallel_degree)
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
repeat_kv = num_key_value_heads < tensor_parallel_degree and tensor_parallel_degree % num_key_value_heads == 0
|
||||
repeat_num = tensor_parallel_degree // num_key_value_heads if repeat_kv else 1
|
||||
if repeat_kv:
|
||||
k_list = split_tensor(k, num_key_value_heads)
|
||||
v_list = split_tensor(v, num_key_value_heads)
|
||||
else:
|
||||
k_list = split_tensor(k, tensor_parallel_degree)
|
||||
v_list = split_tensor(v, tensor_parallel_degree)
|
||||
|
||||
if tensor_parallel_rank is None:
|
||||
res = []
|
||||
@@ -236,8 +242,8 @@ def gqa_qkv_split_func(
|
||||
return paddle.concat(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
@@ -245,8 +251,8 @@ def gqa_qkv_split_func(
|
||||
return paddle.concat(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
@@ -255,8 +261,8 @@ def gqa_qkv_split_func(
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
@@ -264,8 +270,8 @@ def gqa_qkv_split_func(
|
||||
return np.concatenate(
|
||||
[
|
||||
q_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank],
|
||||
v_list[tensor_parallel_rank],
|
||||
k_list[tensor_parallel_rank // repeat_num],
|
||||
v_list[tensor_parallel_rank // repeat_num],
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
@@ -281,8 +287,8 @@ def gqa_qkv_merge_func(num_attention_heads, num_key_value_heads, head_dim):
|
||||
def fn(weight_list, is_column=True):
|
||||
"""fn"""
|
||||
tensor_parallel_degree = len(weight_list)
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree
|
||||
num_attention_heads = num_attention_heads // tensor_parallel_degree # noqa: F823
|
||||
num_key_value_heads = num_key_value_heads // tensor_parallel_degree # noqa: F823
|
||||
|
||||
is_paddle_tensor = not isinstance(weight_list[0], np.ndarray)
|
||||
|
||||
|
||||
@@ -196,6 +196,9 @@ def convert_ndarray_dtype(np_array: np.ndarray,
|
||||
np.ndarray: converted numpy ndarray instance
|
||||
"""
|
||||
source_dtype = convert_dtype(np_array.dtype)
|
||||
if source_dtype == "uint16" and target_dtype == "bfloat16" and paddle.is_compiled_with_custom_device(
|
||||
"iluvatar_gpu"):
|
||||
return np_array.view(dtype=target_dtype)
|
||||
if source_dtype == "uint16" or target_dtype == "bfloat16":
|
||||
if paddle.is_compiled_with_xpu():
|
||||
# xpu not support bf16.
|
||||
|
||||
Reference in New Issue
Block a user